{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6ENR6EQ2b8R8"
      },
      "outputs": [],
      "source": [
        "!pip uninstall -y datasets\n",
        "!pip install datasets==2.18.0\n",
        "!pip install evaluate\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Study using large step size\n",
        "\n",
        "\n",
        "# Prune Layers Based on Entropy Rate Using MNLI Dataset\n",
        "\n",
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(dY[i].unsqueeze(0), dX[i].unsqueeze(0), dim=1, eps=1e-8)**2\n",
        "                cos_squares.append(c2.item())\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [idx+1 for idx, _ in sorted_layers[:num_prune] if idx+1 < len(model.roberta.encoder.layer)]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "\n",
        "\n",
        "def preprocess_function(examples, tok, max_length=128):\n",
        "    return tok(examples['premise'],\n",
        "               examples['hypothesis'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\"roberta-base\", num_labels=3).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=2e-2)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    scaler = GradScaler()\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device), attention_mask=b['attention_mask'].to(device), labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx] += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx] for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune MNLI Acc: {acc:.4f}\")\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "    opt = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device), attention_mask=b['attention_mask'].to(device), labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad = False\n",
        "    for p in model.classifier.parameters(): p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "    opt = torch.optim.Adam(\n",
        "        list(model.classifier.parameters()) + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt, 0, len(train_loader)*3)\n",
        "    scaler = GradScaler()\n",
        "    for epoch in range(3):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device), attention_mask=b['attention_mask'].to(device), labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] MNLI Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    train_ds = load_dataset(\"glue\", \"mnli\", split=\"train[:5000]\").shuffle(seed)\n",
        "    dev_ds = load_dataset(\"glue\", \"mnli\", split=\"validation_matched[:1000]\")\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"premise\",\"hypothesis\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"premise\",\"hypothesis\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "\n",
        "\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "zCT1QIL4N80Q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# MRPC\n",
        "\n",
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )**2  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune low‑ER)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    loras = {}\n",
        "    for idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        W0 = layer.output.dense.weight.data\n",
        "        lora = LoRA(W0, r, alpha).to(W0.device)\n",
        "        def fwd(x, layer=layer, lora=lora):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "        layer.output.dense.forward = fwd\n",
        "        loras[idx] = lora\n",
        "    return loras\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device):\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=2e-2)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(6):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch‐level ER\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch‐level ER\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores):\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*3)\n",
        "\n",
        "    for epoch in range(5):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0):\n",
        "    # (unchanged LoRA stage)\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*6)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(6):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "                scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] MRPC Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # load & preprocess MRPC subset\n",
        "    train_ds = load_dataset(\"glue\", \"mrpc\", split=\"train\")\\\n",
        "               .shuffle(seed).select(range(1000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"mrpc\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    train = train_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                         batched=True,\n",
        "                         remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                    .rename_column(\"label\",\"labels\")\n",
        "    dev   = dev_ds.map(lambda ex: preprocess_function(ex, tokenizer),\n",
        "                       batched=True,\n",
        "                       remove_columns=[\"sentence1\",\"sentence2\",\"idx\"])\\\n",
        "                  .rename_column(\"label\",\"labels\")\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True,\n",
        "                              collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False,\n",
        "                              collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader,\n",
        "                                       dev_loader,\n",
        "                                       device)\n",
        "   # model = prune_and_finetuning(model,\n",
        "    #                             train_loader,\n",
        "     #                            dev_loader,\n",
        "      #                           device,\n",
        "       #                          er_scores)\n",
        "   # lora_only_finetuning(model,\n",
        "        #                 train_loader,\n",
        "         #                dev_loader,\n",
        "          #               device)\n",
        "\n",
        "# In a Jupyter kernel __name__ is \"__main__\", so this runs main()\n",
        "# every time the cell is executed.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "3wlsElsZjjzU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# SST-2\n",
        "\n",
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "# Process-wide filters: hide every UserWarning and FutureWarning from here on\n",
        "# (this also hides potentially useful library warnings).\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )**2  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune the highest-ER layers)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"\n",
        "    Route each encoder layer's output.dense through a LoRA-adapted weight.\n",
        "\n",
        "    For every layer whose `output` still has a `dense` sub-module (layers\n",
        "    pruned with SkipFF have none and are skipped), a LoRA around\n",
        "    dense.weight is created on that weight's device and the Linear\n",
        "    instance's `forward` is overridden with F.linear(x, lora(), dense.bias).\n",
        "\n",
        "    The LoRA modules are not attached to `model` (they live only in the\n",
        "    returned dict and the overriding closures), so model.parameters() and\n",
        "    model.state_dict() do not include them.\n",
        "\n",
        "    Returns:\n",
        "        dict layer_index -> LoRA.\n",
        "    \"\"\"\n",
        "    adapters = {}\n",
        "    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):\n",
        "        if hasattr(enc_layer.output, 'dense'):\n",
        "            base_weight = enc_layer.output.dense.weight.data\n",
        "            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "\n",
        "            def patched_forward(x, layer=enc_layer, lora=adapter):\n",
        "                # bias is looked up at call time\n",
        "                return F.linear(x, lora(), layer.output.dense.bias)\n",
        "\n",
        "            enc_layer.output.dense.forward = patched_forward\n",
        "            adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"\n",
        "    Return the accuracy of `model` over all batches of `dl`.\n",
        "\n",
        "    Switches the model to eval mode (and leaves it there), runs forward\n",
        "    passes without gradients, takes the argmax of the logits and scores it\n",
        "    against batch['labels'] with the `evaluate` \"accuracy\" metric.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    all_preds, all_labels = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            logits = model(\n",
        "                input_ids=batch['input_ids'].to(device),\n",
        "                attention_mask=batch['attention_mask'].to(device),\n",
        "            ).logits\n",
        "            all_labels.extend(batch['labels'].cpu().numpy())\n",
        "            all_preds.extend(logits.argmax(dim=-1).cpu().numpy())\n",
        "    scores = metric.compute(predictions=all_preds, references=all_labels)\n",
        "    return scores[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6, lr=2e-2):\n",
        "    \"\"\"\n",
        "    Stage 1: fully fine-tune roberta-base (num_labels=2) while estimating\n",
        "    the per-layer-pair entropy rate (ER).\n",
        "\n",
        "    Hooks from register_er_hooks capture adjacent-layer activations; after\n",
        "    every optimiser step compute_batch_entropy scores each pair and the\n",
        "    scores are averaged per epoch. The model is evaluated once on\n",
        "    dev_loader at the end, then the hooks are removed.\n",
        "\n",
        "    Args:\n",
        "        train_loader, dev_loader: iterables of dicts with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch device to train on.\n",
        "        num_epochs: training epochs; also sizes the linear LR schedule.\n",
        "        lr: Adam learning rate.\n",
        "    Returns:\n",
        "        (model, last_er): the fine-tuned model and the per-pair ER averages\n",
        "        of the final epoch (None when num_epochs == 0).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    # NOTE(review): the 2e-2 default is far above usual RoBERTa fine-tuning\n",
        "    # rates; presumably deliberate (large-step-size study) - confirm.\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # backward outside the autocast region, as the AMP docs recommend\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch-level ER (also rotates the activation history)\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch-level ER: mean over the batches that produced a score\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune SST2 Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores,\n",
        "                         num_epochs=5, lr=2e-5):\n",
        "    \"\"\"\n",
        "    Stage 2: bypass the FFN of the 4 highest-ER layers, then fine-tune.\n",
        "\n",
        "    Args:\n",
        "        model: model from full_finetuning; modified in place.\n",
        "        er_scores: per-pair ER dict (e.g. last_er from full_finetuning).\n",
        "        num_epochs: fine-tuning epochs; the linear LR schedule is sized to\n",
        "            exactly this many epochs so the LR reaches 0 only at the end.\n",
        "        lr: Adam learning rate.\n",
        "    Returns:\n",
        "        The pruned, fine-tuned model.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    # schedule length must match the loop below; a shorter schedule drives\n",
        "    # the LR to 0 early and the remaining epochs stop learning\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] SST2 Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,\n",
        "                         num_epochs=6, lr=2e-5):\n",
        "    \"\"\"\n",
        "    Stage 3: freeze the encoder and train only LoRA adapters + classifier.\n",
        "\n",
        "    Adapters of rank r (scale alpha/r) are attached with\n",
        "    apply_lora_to_all_layers to every layer that still has output.dense.\n",
        "    Dev accuracy is printed after every epoch.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: training epochs; also sizes the linear LR schedule.\n",
        "        lr: Adam learning rate for the LoRA factors and the classifier.\n",
        "    Returns:\n",
        "        The model (also modified in place).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    # Stage 1 enables gradient checkpointing. With reentrant checkpointing\n",
        "    # (transformers' default - verify for the installed version) and frozen\n",
        "    # embeddings, checkpointed layers receive inputs without requires_grad\n",
        "    # and build no graph, so the LoRA factors inside them get no gradient.\n",
        "    # Making the embedding output require grad restores the graph.\n",
        "    if hasattr(model, \"enable_input_require_grads\"):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=lr\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # backward outside the autocast region, as the AMP docs recommend\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] SST2 Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"\n",
        "    SST-2 entry point: seed the RNGs, build fixed-length (64-token) loaders\n",
        "    from a shuffled 5000-example training subset and the validation split,\n",
        "    then run Stage 1 (full fine-tuning + ER estimation) only.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # load & preprocess SST-2 subset\n",
        "    train_ds = load_dataset(\"glue\", \"sst2\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"sst2\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def encode(batch):\n",
        "        return tokenizer(batch[\"sentence\"], truncation=True,\n",
        "                         padding='max_length', max_length=64)\n",
        "\n",
        "    def to_model_inputs(split):\n",
        "        # tokenise, rename 'label' -> 'labels' (the key the training loops\n",
        "        # read), and drop the raw text / index columns\n",
        "        tokenized = split.map(encode, batched=True)\n",
        "        return (tokenized.rename_column(\"label\", \"labels\")\n",
        "                         .remove_columns([\"sentence\", \"idx\"]))\n",
        "\n",
        "    train = to_model_inputs(train_ds)\n",
        "    dev   = to_model_inputs(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8,  shuffle=True,  collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "# In a Jupyter kernel __name__ is \"__main__\", so this runs main()\n",
        "# every time the cell is executed.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "SOnZT6CA8O1h"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# CoLA\n",
        "\n",
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "# Process-wide filters: hide every UserWarning and FutureWarning from here on\n",
        "# (this also hides potentially useful library warnings).\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )**2  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune the highest-ER layers)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    # sort descending by ER → highest‐entropy layers first\n",
        "    sorted_layers = sorted(er_scores.items(), key=lambda x: x[1], reverse=True)\n",
        "    prune_idxs = [\n",
        "        idx+1\n",
        "        for idx, _ in sorted_layers[:num_prune]\n",
        "        if idx+1 < len(model.roberta.encoder.layer)\n",
        "    ]\n",
        "    for idx in prune_idxs:\n",
        "        layer = model.roberta.encoder.layer[idx]\n",
        "        layer.intermediate.dense = nn.Identity()\n",
        "        layer.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"\n",
        "    Route each encoder layer's output.dense through a LoRA-adapted weight.\n",
        "\n",
        "    For every layer whose `output` still has a `dense` sub-module (layers\n",
        "    pruned with SkipFF have none and are skipped), a LoRA around\n",
        "    dense.weight is created on that weight's device and the Linear\n",
        "    instance's `forward` is overridden with F.linear(x, lora(), dense.bias).\n",
        "\n",
        "    The LoRA modules are not attached to `model` (they live only in the\n",
        "    returned dict and the overriding closures), so model.parameters() and\n",
        "    model.state_dict() do not include them.\n",
        "\n",
        "    Returns:\n",
        "        dict layer_index -> LoRA.\n",
        "    \"\"\"\n",
        "    adapters = {}\n",
        "    for layer_idx, enc_layer in enumerate(model.roberta.encoder.layer):\n",
        "        if hasattr(enc_layer.output, 'dense'):\n",
        "            base_weight = enc_layer.output.dense.weight.data\n",
        "            adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "\n",
        "            def patched_forward(x, layer=enc_layer, lora=adapter):\n",
        "                # bias is looked up at call time\n",
        "                return F.linear(x, lora(), layer.output.dense.bias)\n",
        "\n",
        "            enc_layer.output.dense.forward = patched_forward\n",
        "            adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"\n",
        "    Return the accuracy of `model` over all batches of `dl`.\n",
        "\n",
        "    Switches the model to eval mode (and leaves it there), runs forward\n",
        "    passes without gradients, takes the argmax of the logits and scores it\n",
        "    against batch['labels'] with the `evaluate` \"accuracy\" metric.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    all_preds, all_labels = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            logits = model(\n",
        "                input_ids=batch['input_ids'].to(device),\n",
        "                attention_mask=batch['attention_mask'].to(device),\n",
        "            ).logits\n",
        "            all_labels.extend(batch['labels'].cpu().numpy())\n",
        "            all_preds.extend(logits.argmax(dim=-1).cpu().numpy())\n",
        "    scores = metric.compute(predictions=all_preds, references=all_labels)\n",
        "    return scores[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6, lr=2e-2):\n",
        "    \"\"\"\n",
        "    Stage 1: fully fine-tune roberta-base (num_labels=2) while estimating\n",
        "    the per-layer-pair entropy rate (ER).\n",
        "\n",
        "    Hooks from register_er_hooks capture adjacent-layer activations; after\n",
        "    every optimiser step compute_batch_entropy scores each pair and the\n",
        "    scores are averaged per epoch. The model is evaluated once on\n",
        "    dev_loader at the end, then the hooks are removed.\n",
        "\n",
        "    Args:\n",
        "        train_loader, dev_loader: iterables of dicts with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch device to train on.\n",
        "        num_epochs: training epochs; also sizes the linear LR schedule.\n",
        "        lr: Adam learning rate.\n",
        "    Returns:\n",
        "        (model, last_er): the fine-tuned model and the per-pair ER averages\n",
        "        of the final epoch (None when num_epochs == 0).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    # NOTE(review): the 2e-2 default is far above usual RoBERTa fine-tuning\n",
        "    # rates; presumably deliberate (large-step-size study) - confirm.\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # backward outside the autocast region, as the AMP docs recommend\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "            # batch-level ER (also rotates the activation history)\n",
        "            batch_er = compute_batch_entropy(activations)\n",
        "            for idx, v in batch_er.items():\n",
        "                er_sums[idx]   += v\n",
        "                er_counts[idx] += 1\n",
        "\n",
        "        # epoch-level ER: mean over the batches that produced a score\n",
        "        epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                    for idx in er_sums if er_counts[idx] > 0}\n",
        "        print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "        last_er = epoch_er\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune CoLA Acc: {acc:.4f}\")\n",
        "\n",
        "    remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores,\n",
        "                         num_epochs=5, lr=1e-2):\n",
        "    \"\"\"\n",
        "    Stage 2: bypass the FFN of the 4 highest-ER layers, then fine-tune.\n",
        "\n",
        "    Args:\n",
        "        model: model from full_finetuning; modified in place.\n",
        "        er_scores: per-pair ER dict (e.g. last_er from full_finetuning).\n",
        "        num_epochs: fine-tuning epochs; the linear LR schedule is sized to\n",
        "            exactly this many epochs so the LR reaches 0 only at the end.\n",
        "        lr: Adam learning rate.\n",
        "    Returns:\n",
        "        The pruned, fine-tuned model.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    # NOTE(review): the 1e-2 default (the SST-2 cell uses 2e-5) is very high\n",
        "    # for Adam on a pretrained transformer; confirm it is intended.\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    # schedule length must match the loop below; a shorter schedule drives\n",
        "    # the LR to 0 early and the remaining epochs stop learning\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] CoLA Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,\n",
        "                         num_epochs=6, lr=2e-5):\n",
        "    \"\"\"\n",
        "    Stage 3: freeze the encoder and train only LoRA adapters + classifier.\n",
        "\n",
        "    Adapters of rank r (scale alpha/r) are attached with\n",
        "    apply_lora_to_all_layers to every layer that still has output.dense.\n",
        "    Dev accuracy is printed after every epoch.\n",
        "\n",
        "    Args:\n",
        "        num_epochs: training epochs; also sizes the linear LR schedule.\n",
        "        lr: Adam learning rate for the LoRA factors and the classifier.\n",
        "    Returns:\n",
        "        The model (also modified in place).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad=False\n",
        "    for p in model.classifier.parameters(): p.requires_grad=True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad=True\n",
        "        l.B.requires_grad=True\n",
        "\n",
        "    # Stage 1 enables gradient checkpointing. With reentrant checkpointing\n",
        "    # (transformers' default - verify for the installed version) and frozen\n",
        "    # embeddings, checkpointed layers receive inputs without requires_grad\n",
        "    # and build no graph, so the LoRA factors inside them get no gradient.\n",
        "    # Making the embedding output require grad restores the graph.\n",
        "    if hasattr(model, \"enable_input_require_grads\"):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=lr\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # backward outside the autocast region, as the AMP docs recommend\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] CoLA Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "def main():\n",
        "    \"\"\"\n",
        "    CoLA entry point: seed the RNGs, build fixed-length (64-token) loaders\n",
        "    from a shuffled 5000-example training subset and the validation split,\n",
        "    then run Stage 1 (full fine-tuning + ER estimation) only.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # load & preprocess CoLA subset\n",
        "    train_ds = load_dataset(\"glue\", \"cola\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"cola\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    def encode(batch):\n",
        "        return tokenizer(batch[\"sentence\"], truncation=True,\n",
        "                         padding='max_length', max_length=64)\n",
        "\n",
        "    def to_model_inputs(split):\n",
        "        # tokenise, rename 'label' -> 'labels' (the key the training loops\n",
        "        # read), and drop the raw text / index columns\n",
        "        tokenized = split.map(encode, batched=True)\n",
        "        return (tokenized.rename_column(\"label\", \"labels\")\n",
        "                         .remove_columns([\"sentence\", \"idx\"]))\n",
        "\n",
        "    train = to_model_inputs(train_ds)\n",
        "    dev   = to_model_inputs(dev_ds)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=64)\n",
        "    train_loader = DataLoader(train, batch_size=8,  shuffle=True,  collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev,   batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "# Entry point: runs main() when this code executes as the __main__ module.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "Lh5rJFMZkoe0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# QNLI\n",
        "\n",
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )**2  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high-ER layers)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"\n",
        "    Disable the feed-forward sub-block of the layers with the highest ER.\n",
        "\n",
        "    ``er_scores`` maps an adjacent-pair index i (layers i, i+1) to its ER.\n",
        "    For the ``num_prune`` highest-scoring pairs, layer i+1 is modified in\n",
        "    place: ``intermediate.dense`` becomes ``nn.Identity`` and ``output``\n",
        "    becomes ``SkipFF``.\n",
        "\n",
        "    Returns:\n",
        "        Indices of the pruned layers, highest ER first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    n_layers = len(encoder_layers)\n",
        "    # stable sort on the score, highest entropy first\n",
        "    ranked = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)\n",
        "\n",
        "    targets = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        layer_idx = pair_idx + 1  # the downstream layer of the pair\n",
        "        if layer_idx < n_layers:\n",
        "            targets.append(layer_idx)\n",
        "\n",
        "    for layer_idx in targets:\n",
        "        block = encoder_layers[layer_idx]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return targets\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"\n",
        "    Route every encoder layer's ``output.dense`` through a LoRA weight.\n",
        "\n",
        "    Layers whose ``output`` has no ``dense`` attribute (e.g. replaced by\n",
        "    ``SkipFF`` when pruned) are skipped. For the others, ``forward`` is\n",
        "    overridden on the dense module instance to compute\n",
        "    ``F.linear(x, W0 + scaling * B @ A, bias)``.\n",
        "    NOTE(review): the LoRA modules are not registered as submodules of\n",
        "    ``model``, so they are absent from ``model.parameters()``/``state_dict()``.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping layer index -> LoRA module.\n",
        "    \"\"\"\n",
        "    def make_forward(owner, adapter):\n",
        "        # The factory binds this layer/adapter (no late-binding closure).\n",
        "        def lora_forward(x):\n",
        "            return F.linear(x, adapter(), owner.output.dense.bias)\n",
        "        return lora_forward\n",
        "\n",
        "    adapters = {}\n",
        "    for layer_idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        base_weight = layer.output.dense.weight.data\n",
        "        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "        layer.output.dense.forward = make_forward(layer, adapter)\n",
        "        adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, epochs=6, lr=2e-2):\n",
        "    \"\"\"\n",
        "    Stage 1: fully finetune roberta-base (2 labels) while tracking ER.\n",
        "\n",
        "    ER hooks stay attached for the whole run; after every optimizer step the\n",
        "    per-pair batch ER is accumulated and averaged per epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader, dev_loader: batches with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch device for the model and the batches.\n",
        "        epochs: training epochs (default 6, the previous hard-coded value).\n",
        "        lr: Adam learning rate (default 2e-2, the previous hard-coded value).\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the finetuned model and the last epoch's\n",
        "        {pair index: mean ER} dict (None when epochs == 0).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    # NOTE(review): 2e-2 is far above usual RoBERTa finetuning rates;\n",
        "    # presumably intentional (large-step-size study) - confirm.\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "    try:\n",
        "        for epoch in range(epochs):\n",
        "            er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(input_ids=b['input_ids'].to(device),\n",
        "                                attention_mask=b['attention_mask'].to(device),\n",
        "                                labels=b['labels'].to(device))\n",
        "                # Only the forward pass runs under autocast; backward under\n",
        "                # autocast is discouraged by the PyTorch AMP docs.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                # batch-level ER from the activations recorded by the hooks\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for idx, v in batch_er.items():\n",
        "                    er_sums[idx]   += v\n",
        "                    er_counts[idx] += 1\n",
        "\n",
        "            # epoch-level ER\n",
        "            epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                        for idx in er_sums if er_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "            last_er = epoch_er\n",
        "\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"-> Full Finetune QNLI Acc: {acc:.4f}\")\n",
        "    finally:\n",
        "        # Detach the hooks even if training or evaluation raises.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores,\n",
        "                         epochs=5, lr=1e-5, num_prune=4):\n",
        "    \"\"\"\n",
        "    Stage 2: prune the highest-ER layers in place, then finetune the model.\n",
        "\n",
        "    Args:\n",
        "        model: model returned by Stage 1 (modified in place).\n",
        "        train_loader, dev_loader: batches with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch device for the batches.\n",
        "        er_scores: {pair index: ER} dict, e.g. Stage 1's ``last_er``.\n",
        "        epochs: finetuning epochs (default 5, as before).\n",
        "        lr: Adam learning rate (default 1e-5, as before).\n",
        "        num_prune: number of layers to prune (default 4, as before).\n",
        "\n",
        "    Returns:\n",
        "        The pruned, finetuned model.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=num_prune)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    # Size the schedule by the epochs actually run. It used to cover only 3\n",
        "    # of the 5 epochs, so the LR reached 0 and epochs 4-5 made no updates.\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,\n",
        "                         epochs=6, lr=2e-5):\n",
        "    \"\"\"\n",
        "    Stage 3: freeze the RoBERTa backbone and train only the classifier head\n",
        "    plus LoRA adapters on every layer that still has ``output.dense``.\n",
        "\n",
        "    Args:\n",
        "        model: model from the previous stage (modified in place).\n",
        "        train_loader, dev_loader: batches with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch device for the batches.\n",
        "        r, alpha: LoRA rank and scale (scaling = alpha / r).\n",
        "        epochs: training epochs (default 6, as before).\n",
        "        lr: Adam learning rate (default 2e-5, as before).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    # Stage 1 turned on gradient checkpointing. With the whole backbone\n",
        "    # frozen, the (reentrant) checkpointed layers receive inputs that do not\n",
        "    # require grad, so no gradient reaches the LoRA weights inside them and\n",
        "    # only the classifier would train. Making the embedding output require\n",
        "    # grad restores that gradient path.\n",
        "    if hasattr(model, \"enable_input_require_grads\"):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    trainable = list(model.classifier.parameters()) + [\n",
        "        p for l in loras.values() for p in (l.A, l.B)\n",
        "    ]\n",
        "    opt   = torch.optim.Adam(trainable, lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # backward outside autocast, per the PyTorch AMP docs\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] QNLI Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "\n",
        "def main(run_prune=False, run_lora=False):\n",
        "    \"\"\"\n",
        "    QNLI pipeline: seed the RNGs, tokenize a shuffled 5,000-pair training\n",
        "    subset and the full validation split (question/sentence pairs padded or\n",
        "    truncated to 128 tokens), build DataLoaders and run Stage 1.\n",
        "\n",
        "    Args:\n",
        "        run_prune: also run Stage 2 (prune highest-ER layers + finetune).\n",
        "        run_lora: also run Stage 3 (LoRA-only finetuning).\n",
        "        Both default to False, i.e. only Stage 1 runs.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load & preprocess QNLI (load_dataset is imported at the top of this cell)\n",
        "    train_ds = load_dataset(\"glue\", \"qnli\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    dev_ds   = load_dataset(\"glue\", \"qnli\", split=\"validation\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    def preprocess(examples):\n",
        "        return tokenizer(examples[\"question\"],\n",
        "                         examples[\"sentence\"],\n",
        "                         truncation=True,\n",
        "                         padding='max_length',\n",
        "                         max_length=128)\n",
        "\n",
        "    train = (train_ds.map(preprocess, batched=True)\n",
        "                     .rename_column(\"label\", \"labels\")\n",
        "                     .remove_columns([\"question\", \"sentence\", \"idx\"]))\n",
        "    dev = (dev_ds.map(preprocess, batched=True)\n",
        "                 .rename_column(\"label\", \"labels\")\n",
        "                 .remove_columns([\"question\", \"sentence\", \"idx\"]))\n",
        "\n",
        "    collator     = DataCollatorWithPadding(tokenizer,\n",
        "                                           padding=\"max_length\",\n",
        "                                           max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader   = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    # Later stages are opt-in flags instead of commented-out calls.\n",
        "    if run_prune:\n",
        "        model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    if run_lora:\n",
        "        lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "\n",
        "# Entry point: runs main() when this code executes as the __main__ module.\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "Y0s4HbYJTlfC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# QQP\n",
        "\n",
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
        "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Exactly the same hooks as MIR: we watch each pair of adjacent\n",
        "    layers' output.dense activations.\n",
        "    \"\"\"\n",
        "    layers = model.roberta.encoder.layer\n",
        "    activations = {\n",
        "        i: {'prev_X': None, 'prev_Y': None, 'curr_X': None, 'curr_Y': None}\n",
        "        for i in range(len(layers)-1)\n",
        "    }\n",
        "    hooks = []\n",
        "    for i in range(len(layers)-1):\n",
        "        def hook_x(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_X'] = out.detach()\n",
        "        def hook_y(module, inp, out, idx=i):\n",
        "            activations[idx]['curr_Y'] = out.detach()\n",
        "\n",
        "        hooks.append(layers[i].output.dense.register_forward_hook(hook_x))\n",
        "        hooks.append(layers[i+1].output.dense.register_forward_hook(hook_y))\n",
        "\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    for h in hooks:\n",
        "        h.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    For each adjacent layer‐pair idx,\n",
        "    approximate the conditional entropy rate via\n",
        "      H ≈ d/2 * ln(2πe σ²) + (1/[2(B-1)]) * Σ_{i=1..B-1} cos²(ΔY_i, ΔX_i)\n",
        "    We return only the cosine‐sum term; the additive constant is the\n",
        "    same for all layers and can be dropped for pruning.\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # not enough history yet\n",
        "        if None in (X_prev, Y_prev, X_curr, Y_curr):\n",
        "            buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "            buf['curr_X'], buf['curr_Y'] = None, None\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        # flatten across all non‐batch dims\n",
        "        dX = (X_curr - X_prev).view(B, -1)\n",
        "        dY = (Y_curr - Y_prev).view(B, -1)\n",
        "\n",
        "        if B < 2:\n",
        "            er = 0.0\n",
        "        else:\n",
        "            cos_squares = []\n",
        "            for i in range(1, B):\n",
        "                c2 = F.cosine_similarity(\n",
        "                    dY[i].unsqueeze(0),\n",
        "                    dX[i].unsqueeze(0),\n",
        "                    dim=1, eps=1e-8\n",
        "                )**2  # [1]\n",
        "                cos_squares.append(c2.item())\n",
        "            # sum of cos² over i=1..B-1, then multiplied by 1/(2(B-1))\n",
        "            er = sum(cos_squares) / (2 * (B - 1))\n",
        "\n",
        "        er_scores[idx] = er\n",
        "\n",
        "        # shift history\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune high-ER layers)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"\n",
        "    Disable the feed-forward sub-block of the layers with the highest ER.\n",
        "\n",
        "    ``er_scores`` maps an adjacent-pair index i (layers i, i+1) to its ER.\n",
        "    For the ``num_prune`` highest-scoring pairs, layer i+1 is modified in\n",
        "    place: ``intermediate.dense`` becomes ``nn.Identity`` and ``output``\n",
        "    becomes ``SkipFF``.\n",
        "\n",
        "    Returns:\n",
        "        Indices of the pruned layers, highest ER first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    n_layers = len(encoder_layers)\n",
        "    # stable sort on the score, highest entropy first\n",
        "    ranked = sorted(er_scores.items(), key=lambda kv: kv[1], reverse=True)\n",
        "\n",
        "    targets = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        layer_idx = pair_idx + 1  # the downstream layer of the pair\n",
        "        if layer_idx < n_layers:\n",
        "            targets.append(layer_idx)\n",
        "\n",
        "    for layer_idx in targets:\n",
        "        block = encoder_layers[layer_idx]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return targets\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        L, M = W0.shape\n",
        "        self.B = nn.Parameter(torch.randn(L, r) * 0.01)\n",
        "        self.A = nn.Parameter(torch.zeros(r, M))\n",
        "        self.scaling = alpha / r\n",
        "    def forward(self):\n",
        "        return self.W0 + self.scaling * (self.B @ self.A)\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"\n",
        "    Route every encoder layer's ``output.dense`` through a LoRA weight.\n",
        "\n",
        "    Layers whose ``output`` has no ``dense`` attribute (e.g. replaced by\n",
        "    ``SkipFF`` when pruned) are skipped. For the others, ``forward`` is\n",
        "    overridden on the dense module instance to compute\n",
        "    ``F.linear(x, W0 + scaling * B @ A, bias)``.\n",
        "    NOTE(review): the LoRA modules are not registered as submodules of\n",
        "    ``model``, so they are absent from ``model.parameters()``/``state_dict()``.\n",
        "\n",
        "    Returns:\n",
        "        dict mapping layer index -> LoRA module.\n",
        "    \"\"\"\n",
        "    def make_forward(owner, adapter):\n",
        "        # The factory binds this layer/adapter (no late-binding closure).\n",
        "        def lora_forward(x):\n",
        "            return F.linear(x, adapter(), owner.output.dense.bias)\n",
        "        return lora_forward\n",
        "\n",
        "    adapters = {}\n",
        "    for layer_idx, layer in enumerate(model.roberta.encoder.layer):\n",
        "        if not hasattr(layer.output, 'dense'):\n",
        "            continue\n",
        "        base_weight = layer.output.dense.weight.data\n",
        "        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "        layer.output.dense.forward = make_forward(layer, adapter)\n",
        "        adapters[layer_idx] = adapter\n",
        "    return adapters\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    return tok(examples['sentence1'],\n",
        "               examples['sentence2'],\n",
        "               truncation=True,\n",
        "               padding='max_length',\n",
        "               max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    preds, labs = [], []\n",
        "    with torch.no_grad():\n",
        "        for b in dl:\n",
        "            ids = b['input_ids'].to(device)\n",
        "            mask = b['attention_mask'].to(device)\n",
        "            labs.extend(b['labels'].cpu().numpy())\n",
        "            out = model(input_ids=ids, attention_mask=mask)\n",
        "            preds.extend(torch.argmax(out.logits, -1).cpu().numpy())\n",
        "    return metric.compute(predictions=preds, references=labs)[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, epochs=6, lr=2e-2):\n",
        "    \"\"\"\n",
        "    Stage 1: fully finetune roberta-base (2 labels) while tracking ER.\n",
        "\n",
        "    ER hooks stay attached for the whole run; after every optimizer step the\n",
        "    per-pair batch ER is accumulated and averaged per epoch.\n",
        "\n",
        "    Args:\n",
        "        train_loader, dev_loader: batches with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch device for the model and the batches.\n",
        "        epochs: training epochs (default 6, the previous hard-coded value).\n",
        "        lr: Adam learning rate (default 2e-2, the previous hard-coded value).\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the finetuned model and the last epoch's\n",
        "        {pair index: mean ER} dict (None when epochs == 0).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    # NOTE(review): 2e-2 is far above usual RoBERTa finetuning rates;\n",
        "    # presumably intentional (large-step-size study) - confirm.\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "    try:\n",
        "        for epoch in range(epochs):\n",
        "            er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(input_ids=b['input_ids'].to(device),\n",
        "                                attention_mask=b['attention_mask'].to(device),\n",
        "                                labels=b['labels'].to(device))\n",
        "                # Only the forward pass runs under autocast; backward under\n",
        "                # autocast is discouraged by the PyTorch AMP docs.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                # batch-level ER from the activations recorded by the hooks\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for idx, v in batch_er.items():\n",
        "                    er_sums[idx]   += v\n",
        "                    er_counts[idx] += 1\n",
        "\n",
        "            # epoch-level ER\n",
        "            epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                        for idx in er_sums if er_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "            last_er = epoch_er\n",
        "\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"-> Full Finetune QQP Acc: {acc:.4f}\")\n",
        "    finally:\n",
        "        # Detach the hooks even if training or evaluation raises.\n",
        "        remove_hooks(hooks)\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores,\n",
        "                         epochs=5, lr=1e-5, num_prune=4):\n",
        "    \"\"\"\n",
        "    Stage 2: prune the highest-ER layers in place, then finetune the model.\n",
        "\n",
        "    Args:\n",
        "        model: model returned by Stage 1 (modified in place).\n",
        "        train_loader, dev_loader: batches with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch device for the batches.\n",
        "        er_scores: {pair index: ER} dict, e.g. Stage 1's ``last_er``.\n",
        "        epochs: finetuning epochs (default 5, as before).\n",
        "        lr: Adam learning rate (default 1e-5, as before).\n",
        "        num_prune: number of layers to prune (default 4, as before).\n",
        "\n",
        "    Returns:\n",
        "        The pruned, finetuned model.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=num_prune)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    # Size the schedule by the epochs actually run. It used to cover only 3\n",
        "    # of the 5 epochs, so the LR reached 0 and epochs 4-5 made no updates.\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0,\n",
        "                         epochs=6, lr=2e-5):\n",
        "    \"\"\"\n",
        "    Stage 3: freeze the RoBERTa backbone and train only the classifier head\n",
        "    plus LoRA adapters on every layer that still has ``output.dense``.\n",
        "\n",
        "    Args:\n",
        "        model: model from the previous stage (modified in place).\n",
        "        train_loader, dev_loader: batches with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch device for the batches.\n",
        "        r, alpha: LoRA rank and scale (scaling = alpha / r).\n",
        "        epochs: training epochs (default 6, as before).\n",
        "        lr: Adam learning rate (default 2e-5, as before).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    # Stage 1 turned on gradient checkpointing. With the whole backbone\n",
        "    # frozen, the (reentrant) checkpointed layers receive inputs that do not\n",
        "    # require grad, so no gradient reaches the LoRA weights inside them and\n",
        "    # only the classifier would train. Making the embedding output require\n",
        "    # grad restores that gradient path.\n",
        "    if hasattr(model, \"enable_input_require_grads\"):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    trainable = list(model.classifier.parameters()) + [\n",
        "        p for l in loras.values() for p in (l.A, l.B)\n",
        "    ]\n",
        "    opt   = torch.optim.Adam(trainable, lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # backward outside autocast, per the PyTorch AMP docs\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] QQP Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "\n",
        "def main():\n",
        "    \"\"\"Seed every RNG, build the QQP loaders and run Stage 1 (full finetuning + ER).\n",
        "\n",
        "    The pruning and LoRA stages are left disabled at the bottom; uncomment to run them.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load & preprocess QQP: 5,000 shuffled training pairs + the full validation split\n",
        "    from datasets import load_dataset\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    raw_train = load_dataset(\"glue\", \"qqp\", split=\"train\").shuffle(seed).select(range(5000))\n",
        "    raw_dev = load_dataset(\"glue\", \"qqp\", split=\"validation\")\n",
        "\n",
        "    def preprocess(examples):\n",
        "        return tokenizer(\n",
        "            examples[\"question1\"],\n",
        "            examples[\"question2\"],\n",
        "            truncation=True,\n",
        "            padding=\"max_length\",\n",
        "            max_length=128,\n",
        "        )\n",
        "\n",
        "    def to_model_inputs(split):\n",
        "        # tokenize -> 'label' becomes 'labels' -> drop raw text / index columns\n",
        "        tokenized = split.map(preprocess, batched=True)\n",
        "        tokenized = tokenized.rename_column(\"label\", \"labels\")\n",
        "        return tokenized.remove_columns([\"question1\", \"question2\", \"idx\"])\n",
        "\n",
        "    train = to_model_inputs(raw_train)\n",
        "    dev = to_model_inputs(raw_dev)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    # Disabled follow-up stages:\n",
        "    # model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    # lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "SkbNUsq3Uap6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# RTE\n",
        "\n",
        "# ========================================================\n",
        "# 1) Standard imports and warning suppression\n",
        "# ========================================================\n",
        "from datasets import load_dataset\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "import numpy as np\n",
        "import random\n",
        "from collections import defaultdict\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "warnings.simplefilter(\"ignore\", UserWarning)  # silence UserWarning for the session\n",
        "warnings.simplefilter(\"ignore\", FutureWarning)  # silence FutureWarning for the session\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 2) Entropy‐Rate / Hook Utilities (implements Theorem 2)\n",
        "# ========================================================\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Attach forward hooks that record the `output.dense` activations of every\n",
        "    pair of adjacent encoder layers (i, i+1) -- the same probes MIR uses.\n",
        "\n",
        "    Returns:\n",
        "        hooks: list of hook handles (pass to `remove_hooks` to detach).\n",
        "        activations: dict pair_index -> {'prev_X', 'prev_Y', 'curr_X', 'curr_Y'}.\n",
        "            The hooks write detached outputs of layer i into 'curr_X' and of\n",
        "            layer i+1 into 'curr_Y'; every slot starts as None.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    n_pairs = len(encoder_layers) - 1\n",
        "    activations = {\n",
        "        pair: dict(prev_X=None, prev_Y=None, curr_X=None, curr_Y=None)\n",
        "        for pair in range(n_pairs)\n",
        "    }\n",
        "\n",
        "    def make_hook(pair, slot):\n",
        "        # Factory binds pair/slot eagerly (avoids late-binding closures).\n",
        "        def hook(module, inp, out):\n",
        "            activations[pair][slot] = out.detach()\n",
        "        return hook\n",
        "\n",
        "    hooks = []\n",
        "    for pair in range(n_pairs):\n",
        "        lower = encoder_layers[pair].output.dense\n",
        "        upper = encoder_layers[pair + 1].output.dense\n",
        "        hooks.append(lower.register_forward_hook(make_hook(pair, 'curr_X')))\n",
        "        hooks.append(upper.register_forward_hook(make_hook(pair, 'curr_Y')))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Detach every hook handle produced by `register_er_hooks`.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    Approximate the entropy rate (ER) of each adjacent-layer pair from how the\n",
        "    hooked activations change between two consecutive batches.\n",
        "\n",
        "    With dX = X_curr - X_prev (layer i) and dY = Y_curr - Y_prev (layer i+1),\n",
        "    each flattened per sample, the score for a batch of size B is\n",
        "        sum_{k=1}^{B-1} cos^2(dY_k, dX_k) / (2 * (B - 1))\n",
        "    (sample 0 is excluded from the sum, as in the original estimator).\n",
        "\n",
        "    Args:\n",
        "        activations: dict from `register_er_hooks`. Updated in place: the\n",
        "            'curr_*' tensors move into 'prev_*' and 'curr_*' is reset to None.\n",
        "        sigma2: unused; kept for interface compatibility.\n",
        "\n",
        "    Returns:\n",
        "        dict pair_index -> float, only for pairs with both a previous and a\n",
        "        current batch of matching size (a batch of size < 2 scores 0.0).\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # The current batch always becomes the reference for the next call.\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "        # Identity checks: `None in (tensor, ...)` would rely on tensor == None.\n",
        "        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):\n",
        "            continue\n",
        "        # Skip when batch sizes differ (e.g. the smaller last batch of an epoch).\n",
        "        if X_prev.size(0) != X_curr.size(0) or Y_prev.size(0) != Y_curr.size(0):\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        if B < 2:\n",
        "            er_scores[idx] = 0.0\n",
        "            continue\n",
        "\n",
        "        # Hooked outputs can be fp16 under autocast; upcast before reducing over\n",
        "        # every feature to avoid overflow / precision loss.\n",
        "        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)\n",
        "        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)\n",
        "        # One batched call instead of B-1 separate `.item()` device syncs.\n",
        "        cos_sq = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8).pow(2)\n",
        "        er_scores[idx] = cos_sq.sum().item() / (2 * (B - 1))\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 3) Pruning Utilities with SkipFF (prune highest‑ER layers)\n",
        "# ========================================================\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Drop-in replacement for a layer's `output` block that skips the FFN.\n",
        "\n",
        "    `forward` ignores `hidden_states` and returns `input_tensor` unchanged.\n",
        "    NOTE(review): assumes the caller passes the residual branch as the second\n",
        "    argument, as transformers' RobertaLayer appears to -- confirm.\n",
        "    \"\"\"\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        del hidden_states  # the FFN input is intentionally discarded\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"\n",
        "    Bypass the feed-forward block of the layers with the HIGHEST entropy rate.\n",
        "\n",
        "    Score key `idx` refers to the layer pair (idx, idx+1); for each of the\n",
        "    `num_prune` top-scoring pairs the later layer idx+1 is pruned by swapping\n",
        "    its intermediate dense for Identity and its output block for SkipFF.\n",
        "\n",
        "    Returns:\n",
        "        list of pruned layer indices, ordered from highest to lowest ER.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 4) LoRA Modules (unchanged)\n",
        "# ========================================================\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank update of a frozen weight: W = W0 + (alpha / r) * (B @ A).\n",
        "\n",
        "    W0 is kept as a detached buffer copy. With (L, M) = W0.shape, B (L x r)\n",
        "    starts as small Gaussian noise and A (r x M) as zeros, so the initial\n",
        "    update is exactly zero. `forward()` takes no input and returns W.\n",
        "    \"\"\"\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        n_rows, n_cols = W0.shape\n",
        "        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))\n",
        "        self.A = nn.Parameter(torch.zeros(r, n_cols))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        low_rank = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"\n",
        "    Route each encoder layer's `output.dense` through a LoRA-augmented weight.\n",
        "\n",
        "    Layers whose `output` has no `dense` attribute (e.g. replaced by SkipFF)\n",
        "    are skipped. The Linear's `forward` is overridden on the instance and\n",
        "    reuses its own bias. The LoRA modules are not registered as submodules\n",
        "    of `model`, so callers must keep the returned dict to train them.\n",
        "\n",
        "    Returns:\n",
        "        dict layer_index -> LoRA module.\n",
        "    \"\"\"\n",
        "    def _install(layer):\n",
        "        base_weight = layer.output.dense.weight.data\n",
        "        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "\n",
        "        def fwd(x, layer=layer, lora=adapter):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "\n",
        "        layer.output.dense.forward = fwd\n",
        "        return adapter\n",
        "\n",
        "    return {\n",
        "        idx: _install(layer)\n",
        "        for idx, layer in enumerate(model.roberta.encoder.layer)\n",
        "        if hasattr(layer.output, 'dense')\n",
        "    }\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 5) Data + Eval Helpers\n",
        "# ========================================================\n",
        "def preprocess_function(examples, tok, max_length=64):\n",
        "    \"\"\"Tokenize the 'sentence1'/'sentence2' columns as pairs, truncated and padded to `max_length`.\"\"\"\n",
        "    pair = (examples['sentence1'], examples['sentence2'])\n",
        "    return tok(*pair, truncation=True, padding='max_length', max_length=max_length)\n",
        "\n",
        "def evaluate_model(model, dl, device):\n",
        "    \"\"\"Return the accuracy of `model` on dataloader `dl`.\n",
        "\n",
        "    Switches the model to eval mode (and leaves it there), runs without\n",
        "    gradients, and scores argmax-of-logits predictions with the HF\n",
        "    `evaluate` accuracy metric.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"accuracy\")\n",
        "    predictions, references = [], []\n",
        "    with torch.no_grad():\n",
        "        for batch in dl:\n",
        "            input_ids = batch['input_ids'].to(device)\n",
        "            attention_mask = batch['attention_mask'].to(device)\n",
        "            references.extend(batch['labels'].cpu().numpy())\n",
        "            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits\n",
        "            predictions.extend(logits.argmax(dim=-1).cpu().numpy())\n",
        "    result = metric.compute(predictions=predictions, references=references)\n",
        "    return result[\"accuracy\"]\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 6) Training Stages (using ER instead of MIR)\n",
        "# ========================================================\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6, lr=2e-2):\n",
        "    \"\"\"\n",
        "    Stage 1: fully finetune roberta-base on RTE while tracking per-pair ER.\n",
        "\n",
        "    Args:\n",
        "        train_loader / dev_loader: DataLoaders of dicts with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch.device for the model and batches.\n",
        "        num_epochs: training epochs; also sizes the linear LR schedule.\n",
        "        lr: Adam learning rate (default 2e-2, the large step size under study).\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): the finetuned model and the final epoch's averaged\n",
        "        ER dict (None if num_epochs == 0).\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 1: Full Finetuning & ER Estimation ===\")\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=2\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for b in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(input_ids=b['input_ids'].to(device),\n",
        "                                attention_mask=b['attention_mask'].to(device),\n",
        "                                labels=b['labels'].to(device))\n",
        "                # Backward runs outside autocast, per the PyTorch AMP recipe.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                # batch-level ER\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for idx, v in batch_er.items():\n",
        "                    er_sums[idx]   += v\n",
        "                    er_counts[idx] += 1\n",
        "\n",
        "            # epoch-level ER\n",
        "            epoch_er = {idx: er_sums[idx]/er_counts[idx]\n",
        "                        for idx in er_sums if er_counts[idx] > 0}\n",
        "            print(f\"[Epoch {epoch+1}] approx Entropy Rate:\", epoch_er)\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        # Detach the probes before evaluation, and even if training raises.\n",
        "        remove_hooks(hooks)\n",
        "\n",
        "    acc = evaluate_model(model, dev_loader, device)\n",
        "    print(f\"-> Full Finetune RTE Acc: {acc:.4f}\")\n",
        "    return model, last_er\n",
        "\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores, num_epochs=5):\n",
        "    \"\"\"\n",
        "    Stage 2: prune the highest-ER layers, then finetune every parameter.\n",
        "\n",
        "    Args:\n",
        "        model: model from `full_finetuning`; pruned in place.\n",
        "        er_scores: dict pair_index -> ER, as returned by `full_finetuning`.\n",
        "        num_epochs: finetuning epochs; also sizes the linear LR schedule so the\n",
        "            learning rate only reaches zero at the end of the last epoch.\n",
        "\n",
        "    Returns:\n",
        "        The pruned, finetuned model.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 2: Prune (High‐ER) & Finetuning ===\")\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers (highest‐ER):\", prune_idxs)\n",
        "\n",
        "    opt   = torch.optim.Adam(model.parameters(), lr=1e-5)\n",
        "    # Schedule length must match the loop; a shorter schedule leaves LR = 0\n",
        "    # for the remaining epochs.\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(input_ids=b['input_ids'].to(device),\n",
        "                        attention_mask=b['attention_mask'].to(device),\n",
        "                        labels=b['labels'].to(device))\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[Prune FT Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "\n",
        "    return model\n",
        "\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):\n",
        "    \"\"\"\n",
        "    Stage 3: freeze the backbone and train only the LoRA adapters + classifier head.\n",
        "\n",
        "    Args:\n",
        "        model: RobertaForSequenceClassification (possibly pruned), already on `device`.\n",
        "        train_loader / dev_loader: DataLoaders yielding dicts with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch.device the batches are moved to.\n",
        "        r, alpha: LoRA rank and scaling numerator (scaling = alpha / r).\n",
        "        num_epochs: training epochs; also sizes the linear LR schedule.\n",
        "\n",
        "    Prints dev accuracy after every epoch; returns None.\n",
        "    \"\"\"\n",
        "    print(\"=== Stage 3: LoRA Finetuning ===\")\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters():\n",
        "        p.requires_grad = False\n",
        "    for p in model.classifier.parameters():\n",
        "        p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    # HF gradient checkpointing (reentrant by default in recent transformers\n",
        "    # releases) only back-propagates through a checkpointed layer if that\n",
        "    # layer's *input* requires grad. With the embeddings frozen, the LoRA A/B\n",
        "    # inside the encoder would get no gradients, so force the embedding\n",
        "    # output to require grad.\n",
        "    if getattr(model, \"is_gradient_checkpointing\", False):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt   = torch.optim.Adam(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(opt,\n",
        "                                            num_warmup_steps=0,\n",
        "                                            num_training_steps=len(train_loader)*num_epochs)\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    # If you want to continue monitoring ER during LoRA, you can re-hook here.\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for b in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(input_ids=b['input_ids'].to(device),\n",
        "                            attention_mask=b['attention_mask'].to(device),\n",
        "                            labels=b['labels'].to(device))\n",
        "            # Backward runs outside autocast, per the PyTorch AMP recipe.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "        acc = evaluate_model(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] RTE Acc: {acc:.4f}\")\n",
        "\n",
        "\n",
        "# ========================================================\n",
        "# 7) Main Entrypoint\n",
        "# ========================================================\n",
        "\n",
        "def main():\n",
        "    \"\"\"Seed every RNG, build the RTE loaders and run Stage 1 (full finetuning + ER).\n",
        "\n",
        "    The pruning and LoRA stages are left disabled at the bottom; uncomment to run them.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
        "        seed_fn(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    # Load & preprocess RTE: shuffled train split + full validation split\n",
        "    from datasets import load_dataset\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "\n",
        "    raw_train = load_dataset(\"glue\", \"rte\", split=\"train\").shuffle(seed)\n",
        "    raw_dev = load_dataset(\"glue\", \"rte\", split=\"validation\")\n",
        "\n",
        "    def preprocess(examples):\n",
        "        return tokenizer(\n",
        "            examples[\"sentence1\"],\n",
        "            examples[\"sentence2\"],\n",
        "            truncation=True,\n",
        "            padding=\"max_length\",\n",
        "            max_length=128,\n",
        "        )\n",
        "\n",
        "    def to_model_inputs(split):\n",
        "        # tokenize -> 'label' becomes 'labels' -> drop raw text / index columns\n",
        "        tokenized = split.map(preprocess, batched=True)\n",
        "        tokenized = tokenized.rename_column(\"label\", \"labels\")\n",
        "        return tokenized.remove_columns([\"sentence1\", \"sentence2\", \"idx\"])\n",
        "\n",
        "    train = to_model_inputs(raw_train)\n",
        "    dev = to_model_inputs(raw_dev)\n",
        "\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "    train_loader = DataLoader(train, batch_size=8, shuffle=True, collate_fn=collator)\n",
        "    dev_loader = DataLoader(dev, batch_size=16, shuffle=False, collate_fn=collator)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    # Disabled follow-up stages:\n",
        "    # model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    # lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "d7uteGgVU6ui"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# STS-B\n",
        "\n",
        "import numpy as np\n",
        "import random\n",
        "import math\n",
        "import warnings\n",
        "\n",
        "# Monkey-patch numpy.array to ignore the `copy` argument (workaround for\n",
        "# NumPy 2.0, whose stricter copy=False semantics can break older callers).\n",
        "# The wrapper always falls back to NumPy's default copy behaviour.\n",
        "# Wrap the *original* function, so re-running this cell does not stack\n",
        "# wrappers around an already-patched np.array.\n",
        "_np_array = getattr(np.array, \"__wrapped__\", np.array)\n",
        "def _patched_array(obj, *args, copy=False, **kwargs):\n",
        "    return _np_array(obj, *args, **kwargs)\n",
        "_patched_array.__wrapped__ = _np_array\n",
        "np.array = _patched_array\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.cuda.amp import autocast, GradScaler\n",
        "from torch.utils.data import DataLoader\n",
        "import evaluate\n",
        "\n",
        "from transformers import (\n",
        "    RobertaForSequenceClassification,\n",
        "    RobertaTokenizerFast,\n",
        "    DataCollatorWithPadding,\n",
        "    get_linear_schedule_with_warmup,\n",
        ")\n",
        "from datasets import load_dataset\n",
        "\n",
        "from collections import defaultdict\n",
        "\n",
        "warnings.simplefilter(\"ignore\", UserWarning)  # silence UserWarning for the session\n",
        "warnings.simplefilter(\"ignore\", FutureWarning)  # silence FutureWarning for the session\n",
        "\n",
        "# ─── 1) Entropy‐Rate Hooks (Theorem 2) ─────────────────────────────────────\n",
        "def register_er_hooks(model):\n",
        "    \"\"\"\n",
        "    Hook `output.dense` of every adjacent encoder-layer pair (i, i+1).\n",
        "\n",
        "    Returns (hooks, activations): the hook handles, and a dict\n",
        "    pair_index -> {'prev_X', 'prev_Y', 'curr_X', 'curr_Y'} where the hooks\n",
        "    store detached outputs of layer i in 'curr_X' and of layer i+1 in 'curr_Y'.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    n_pairs = len(encoder_layers) - 1\n",
        "    activations = {\n",
        "        pair: dict(prev_X=None, prev_Y=None, curr_X=None, curr_Y=None)\n",
        "        for pair in range(n_pairs)\n",
        "    }\n",
        "\n",
        "    def make_hook(pair, slot):\n",
        "        # Factory binds pair/slot eagerly (avoids late-binding closures).\n",
        "        def hook(module, inp, out):\n",
        "            activations[pair][slot] = out.detach()\n",
        "        return hook\n",
        "\n",
        "    hooks = []\n",
        "    for pair in range(n_pairs):\n",
        "        lower = encoder_layers[pair].output.dense\n",
        "        upper = encoder_layers[pair + 1].output.dense\n",
        "        hooks.append(lower.register_forward_hook(make_hook(pair, 'curr_X')))\n",
        "        hooks.append(upper.register_forward_hook(make_hook(pair, 'curr_Y')))\n",
        "    return hooks, activations\n",
        "\n",
        "def remove_hooks(hooks):\n",
        "    \"\"\"Detach every hook handle produced by `register_er_hooks`.\"\"\"\n",
        "    for handle in hooks:\n",
        "        handle.remove()\n",
        "\n",
        "def compute_batch_entropy(activations, sigma2=1.0):\n",
        "    \"\"\"\n",
        "    Approximate the entropy rate (ER) of each adjacent-layer pair from how the\n",
        "    hooked activations change between two consecutive batches.\n",
        "\n",
        "    With dX = X_curr - X_prev (layer i) and dY = Y_curr - Y_prev (layer i+1),\n",
        "    each flattened per sample, the score for a batch of size B is\n",
        "        sum_{k=1}^{B-1} cos^2(dY_k, dX_k) / (2 * (B - 1))\n",
        "    (sample 0 is excluded from the sum, as in the original estimator).\n",
        "\n",
        "    Args:\n",
        "        activations: dict from `register_er_hooks`. Updated in place: the\n",
        "            'curr_*' tensors move into 'prev_*' and 'curr_*' is reset to None.\n",
        "        sigma2: unused; kept for interface compatibility.\n",
        "\n",
        "    Returns:\n",
        "        dict pair_index -> float, only for pairs with both a previous and a\n",
        "        current batch of matching size (a batch of size < 2 scores 0.0).\n",
        "    \"\"\"\n",
        "    er_scores = {}\n",
        "    for idx, buf in activations.items():\n",
        "        X_prev, Y_prev = buf['prev_X'], buf['prev_Y']\n",
        "        X_curr, Y_curr = buf['curr_X'], buf['curr_Y']\n",
        "\n",
        "        # The current batch always becomes the reference for the next call.\n",
        "        buf['prev_X'], buf['prev_Y'] = X_curr, Y_curr\n",
        "        buf['curr_X'], buf['curr_Y'] = None, None\n",
        "\n",
        "        # Identity checks: `None in (tensor, ...)` would rely on tensor == None.\n",
        "        if any(t is None for t in (X_prev, Y_prev, X_curr, Y_curr)):\n",
        "            continue\n",
        "        # Need the same batch size for both layers (the last batch may be smaller).\n",
        "        if X_prev.size(0) != X_curr.size(0) or Y_prev.size(0) != Y_curr.size(0):\n",
        "            continue\n",
        "\n",
        "        B = X_curr.size(0)\n",
        "        if B < 2:\n",
        "            er_scores[idx] = 0.0\n",
        "            continue\n",
        "\n",
        "        # Hooked outputs can be fp16 under autocast; upcast before reducing over\n",
        "        # every feature to avoid overflow / precision loss.\n",
        "        dX = (X_curr.float() - X_prev.float()).reshape(B, -1)\n",
        "        dY = (Y_curr.float() - Y_prev.float()).reshape(B, -1)\n",
        "        # One batched call instead of B-1 separate `.item()` device syncs.\n",
        "        cos_sq = F.cosine_similarity(dY[1:], dX[1:], dim=1, eps=1e-8).pow(2)\n",
        "        er_scores[idx] = cos_sq.sum().item() / (2 * (B - 1))\n",
        "\n",
        "    return er_scores\n",
        "\n",
        "# ─── 2) Pruning Utilities ─────────────────────────────────────────────────\n",
        "class SkipFF(nn.Module):\n",
        "    \"\"\"Drop-in replacement for a layer's `output` block that skips the FFN.\n",
        "\n",
        "    `forward` ignores `hidden_states` and returns `input_tensor` unchanged.\n",
        "    NOTE(review): assumes the caller passes the residual branch as the second\n",
        "    argument, as transformers' RobertaLayer appears to -- confirm.\n",
        "    \"\"\"\n",
        "    def forward(self, hidden_states, input_tensor=None):\n",
        "        del hidden_states  # the FFN input is intentionally discarded\n",
        "        return input_tensor\n",
        "\n",
        "def prune_er_layers(model, er_scores, num_prune=4):\n",
        "    \"\"\"\n",
        "    Bypass the feed-forward block of the layers with the HIGHEST entropy rate.\n",
        "\n",
        "    Score key `idx` refers to the pair (idx, idx+1); the later layer idx+1 of\n",
        "    each of the `num_prune` top pairs gets Identity for its intermediate dense\n",
        "    and SkipFF for its output block. Returns the pruned indices, highest ER first.\n",
        "    \"\"\"\n",
        "    encoder_layers = model.roberta.encoder.layer\n",
        "    ranked = sorted(er_scores.items(), key=lambda item: item[1], reverse=True)\n",
        "    prune_idxs = []\n",
        "    for pair_idx, _score in ranked[:num_prune]:\n",
        "        target = pair_idx + 1\n",
        "        if target < len(encoder_layers):\n",
        "            prune_idxs.append(target)\n",
        "\n",
        "    for target in prune_idxs:\n",
        "        block = encoder_layers[target]\n",
        "        block.intermediate.dense = nn.Identity()\n",
        "        block.output = SkipFF()\n",
        "    return prune_idxs\n",
        "\n",
        "# ─── 3) LoRA Modules ──────────────────────────────────────────────────────\n",
        "class LoRA(nn.Module):\n",
        "    \"\"\"Low-rank update of a frozen weight: W = W0 + (alpha / r) * (B @ A).\n",
        "\n",
        "    W0 is kept as a detached buffer copy. With (L, M) = W0.shape, B (L x r)\n",
        "    starts as small Gaussian noise and A (r x M) as zeros, so the initial\n",
        "    update is exactly zero. `forward()` takes no input and returns W.\n",
        "    \"\"\"\n",
        "    def __init__(self, W0, r=2, alpha=1.0):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"W0\", W0.clone().detach())\n",
        "        n_rows, n_cols = W0.shape\n",
        "        self.B = nn.Parameter(0.01 * torch.randn(n_rows, r))\n",
        "        self.A = nn.Parameter(torch.zeros(r, n_cols))\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self):\n",
        "        low_rank = self.B @ self.A\n",
        "        return self.W0 + self.scaling * low_rank\n",
        "\n",
        "def apply_lora_to_all_layers(model, r=2, alpha=1.0):\n",
        "    \"\"\"\n",
        "    Route each encoder layer's `output.dense` through a LoRA-augmented weight.\n",
        "\n",
        "    Layers whose `output` has no `dense` attribute (e.g. replaced by SkipFF)\n",
        "    are skipped. The Linear's `forward` is overridden on the instance and\n",
        "    reuses its own bias. The LoRA modules are not registered as submodules\n",
        "    of `model`, so callers must keep the returned dict to train them.\n",
        "\n",
        "    Returns:\n",
        "        dict layer_index -> LoRA module.\n",
        "    \"\"\"\n",
        "    def _install(layer):\n",
        "        base_weight = layer.output.dense.weight.data\n",
        "        adapter = LoRA(base_weight, r, alpha).to(base_weight.device)\n",
        "\n",
        "        def fwd(x, layer=layer, lora=adapter):\n",
        "            return F.linear(x, lora(), layer.output.dense.bias)\n",
        "\n",
        "        layer.output.dense.forward = fwd\n",
        "        return adapter\n",
        "\n",
        "    return {\n",
        "        idx: _install(layer)\n",
        "        for idx, layer in enumerate(model.roberta.encoder.layer)\n",
        "        if hasattr(layer.output, 'dense')\n",
        "    }\n",
        "\n",
        "# ─── 4) STS-B Evaluation with Flattened References ─────────────────────────\n",
        "def evaluate_stsb(model, dataloader, device):\n",
        "    \"\"\"Score an STS-B regressor on `dataloader` with the GLUE 'stsb' metric.\n",
        "\n",
        "    Predictions are the squeezed single-logit outputs; labels are flattened\n",
        "    to plain floats whether a batch yields [x, ...] or [[x], ...]. Returns the\n",
        "    metric dict (callers read its 'pearson' and 'spearmanr' entries).\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    metric = evaluate.load(\"glue\", \"stsb\")\n",
        "    preds, refs = [], []\n",
        "\n",
        "    def _as_float(value):\n",
        "        # list-of-lists labels keep only their first element\n",
        "        if isinstance(value, (list, tuple, np.ndarray)):\n",
        "            return float(value[0])\n",
        "        return float(value)\n",
        "\n",
        "    with torch.no_grad():\n",
        "        for batch in dataloader:\n",
        "            logits = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "            ).logits\n",
        "            flat = logits.squeeze(-1).cpu().tolist()\n",
        "            if isinstance(flat, list):\n",
        "                preds.extend(flat)\n",
        "            else:\n",
        "                preds.append(flat)  # a 1-sample batch squeezes to a scalar\n",
        "            refs.extend(_as_float(x) for x in batch[\"labels\"].cpu().tolist())\n",
        "    return metric.compute(predictions=preds, references=refs)\n",
        "\n",
        "# ─── 5) Training Stages ────────────────────────────────────────────────────\n",
        "def full_finetuning(train_loader, dev_loader, device, num_epochs=6, lr=2e-2):\n",
        "    \"\"\"\n",
        "    Stage 1 for STS-B: fully finetune roberta-base as a 1-output regressor\n",
        "    while tracking the per-pair entropy rate (ER).\n",
        "\n",
        "    Args:\n",
        "        train_loader / dev_loader: DataLoaders of dicts with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch.device for model and batches.\n",
        "        num_epochs: training epochs; also sizes the linear LR schedule.\n",
        "        lr: AdamW learning rate (default 2e-2, the large step size under study).\n",
        "\n",
        "    Returns:\n",
        "        (model, last_er): finetuned model and the final epoch's averaged ER\n",
        "        dict (None if num_epochs == 0).\n",
        "    \"\"\"\n",
        "    model = RobertaForSequenceClassification.from_pretrained(\n",
        "        \"roberta-base\", num_labels=1\n",
        "    ).to(device)\n",
        "    model.gradient_checkpointing_enable()\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=lr)\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*num_epochs\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    hooks, activations = register_er_hooks(model)\n",
        "    last_er = None\n",
        "    try:\n",
        "        for epoch in range(num_epochs):\n",
        "            er_sums, er_counts = defaultdict(float), defaultdict(int)\n",
        "            model.train()\n",
        "            for batch in train_loader:\n",
        "                opt.zero_grad()\n",
        "                with autocast():\n",
        "                    out = model(\n",
        "                        input_ids=batch[\"input_ids\"].to(device),\n",
        "                        attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                        labels=batch[\"labels\"].to(device),\n",
        "                    )\n",
        "                # Backward runs outside autocast, per the PyTorch AMP recipe.\n",
        "                scaler.scale(out.loss).backward()\n",
        "                scaler.step(opt)\n",
        "                scaler.update()\n",
        "                sched.step()\n",
        "\n",
        "                batch_er = compute_batch_entropy(activations)\n",
        "                for idx, v in batch_er.items():\n",
        "                    er_sums[idx] += v\n",
        "                    er_counts[idx] += 1\n",
        "\n",
        "            epoch_er = {\n",
        "                idx: er_sums[idx] / er_counts[idx]\n",
        "                for idx in er_sums if er_counts[idx] > 0\n",
        "            }\n",
        "            print(f\"[Epoch {epoch+1}] ER:\", epoch_er)\n",
        "            last_er = epoch_er\n",
        "    finally:\n",
        "        # Detach the probes before evaluation, and even if training raises.\n",
        "        remove_hooks(hooks)\n",
        "\n",
        "    metrics = evaluate_stsb(model, dev_loader, device)\n",
        "    print(f\"STS‑B Pearson: {metrics['pearson']:.4f}, Spearman: {metrics['spearmanr']:.4f}\")\n",
        "    return model, last_er\n",
        "\n",
        "def prune_and_finetuning(model, train_loader, dev_loader, device, er_scores, num_epochs=5):\n",
        "    \"\"\"\n",
        "    Stage 2 for STS-B: prune the highest-ER layers, then finetune everything.\n",
        "\n",
        "    Args:\n",
        "        model: model from `full_finetuning`; pruned in place.\n",
        "        er_scores: dict pair_index -> ER, as returned by `full_finetuning`.\n",
        "        num_epochs: finetuning epochs; also sizes the linear LR schedule so the\n",
        "            learning rate only reaches zero at the end of the last epoch.\n",
        "\n",
        "    Returns:\n",
        "        The pruned, finetuned model.\n",
        "    \"\"\"\n",
        "    prune_idxs = prune_er_layers(model, er_scores, num_prune=4)\n",
        "    print(\"Pruned layers:\", prune_idxs)\n",
        "    opt = torch.optim.AdamW(model.parameters(), lr=1e-5)\n",
        "    # Schedule length must match the loop; a shorter schedule leaves LR = 0\n",
        "    # for the remaining epochs.\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*num_epochs\n",
        "    )\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            out = model(\n",
        "                input_ids=batch[\"input_ids\"].to(device),\n",
        "                attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                labels=batch[\"labels\"].to(device),\n",
        "            )\n",
        "            out.loss.backward()\n",
        "            opt.step()\n",
        "            sched.step()\n",
        "\n",
        "        metrics = evaluate_stsb(model, dev_loader, device)\n",
        "        print(f\"[Prune Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}\")\n",
        "    return model\n",
        "\n",
        "def lora_only_finetuning(model, train_loader, dev_loader, device, r=2, alpha=1.0, num_epochs=6):\n",
        "    \"\"\"\n",
        "    Stage 3 for STS-B: freeze the backbone and train only LoRA adapters + head.\n",
        "\n",
        "    Args:\n",
        "        model: RobertaForSequenceClassification (possibly pruned), already on `device`.\n",
        "        train_loader / dev_loader: DataLoaders of dicts with 'input_ids',\n",
        "            'attention_mask' and 'labels'.\n",
        "        device: torch.device the batches are moved to.\n",
        "        r, alpha: LoRA rank and scaling numerator (scaling = alpha / r).\n",
        "        num_epochs: training epochs; also sizes the linear LR schedule.\n",
        "\n",
        "    Prints dev Pearson after every epoch; returns None.\n",
        "    \"\"\"\n",
        "    torch.cuda.empty_cache()\n",
        "    loras = apply_lora_to_all_layers(model, r, alpha)\n",
        "    for p in model.roberta.parameters(): p.requires_grad = False\n",
        "    for p in model.classifier.parameters(): p.requires_grad = True\n",
        "    for l in loras.values():\n",
        "        l.A.requires_grad = True\n",
        "        l.B.requires_grad = True\n",
        "\n",
        "    # HF gradient checkpointing (reentrant by default in recent transformers\n",
        "    # releases) only back-propagates through a checkpointed layer if that\n",
        "    # layer's *input* requires grad. With the embeddings frozen, the LoRA A/B\n",
        "    # inside the encoder would get no gradients, so force the embedding\n",
        "    # output to require grad.\n",
        "    if getattr(model, \"is_gradient_checkpointing\", False):\n",
        "        model.enable_input_require_grads()\n",
        "\n",
        "    opt = torch.optim.AdamW(\n",
        "        list(model.classifier.parameters())\n",
        "        + [p for l in loras.values() for p in (l.A, l.B)],\n",
        "        lr=2e-5\n",
        "    )\n",
        "    sched = get_linear_schedule_with_warmup(\n",
        "        opt, num_warmup_steps=0, num_training_steps=len(train_loader)*num_epochs\n",
        "    )\n",
        "    scaler = GradScaler()\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        model.train()\n",
        "        for batch in train_loader:\n",
        "            opt.zero_grad()\n",
        "            with autocast():\n",
        "                out = model(\n",
        "                    input_ids=batch[\"input_ids\"].to(device),\n",
        "                    attention_mask=batch[\"attention_mask\"].to(device),\n",
        "                    labels=batch[\"labels\"].to(device),\n",
        "                )\n",
        "            # Backward runs outside autocast, per the PyTorch AMP recipe.\n",
        "            scaler.scale(out.loss).backward()\n",
        "            scaler.step(opt)\n",
        "            scaler.update()\n",
        "            sched.step()\n",
        "\n",
        "        metrics = evaluate_stsb(model, dev_loader, device)\n",
        "        print(f\"[LoRA Epoch {epoch+1}] Pearson: {metrics['pearson']:.4f}\")\n",
        "\n",
        "# ─── 6) Main Entrypoint ────────────────────────────────────────────────────\n",
        "def main(run_prune=False, run_lora=False):\n",
        "    \"\"\"Run the STS-B experiment: full fine-tuning plus optional later stages.\n",
        "\n",
        "    Seeds ``random``/NumPy/PyTorch, tokenizes GLUE STS-B sentence pairs to a\n",
        "    fixed length of 128 and runs ``full_finetuning``. The entropy-rate\n",
        "    pruning and LoRA-only stages are opt-in keyword flags rather than\n",
        "    commented-out calls; with the defaults, behaviour is unchanged.\n",
        "\n",
        "    Args:\n",
        "        run_prune: run ``prune_and_finetuning`` on the fully fine-tuned model\n",
        "            with the ER scores ``full_finetuning`` returned (default False).\n",
        "        run_lora: run ``lora_only_finetuning`` on the current model\n",
        "            (default False).\n",
        "\n",
        "    Returns:\n",
        "        The model produced by the last stage that returned one.\n",
        "    \"\"\"\n",
        "    seed = 42\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "    tokenizer = RobertaTokenizerFast.from_pretrained(\"roberta-base\")\n",
        "    collator = DataCollatorWithPadding(tokenizer, padding=\"max_length\", max_length=128)\n",
        "\n",
        "    def preprocess(batch):\n",
        "        # Tokenize the pair and attach float regression targets in a single\n",
        "        # batched pass (the task is scored with Pearson correlation).\n",
        "        enc = tokenizer(\n",
        "            batch[\"sentence1\"], batch[\"sentence2\"],\n",
        "            truncation=True, padding=\"max_length\", max_length=128\n",
        "        )\n",
        "        enc[\"labels\"] = [float(v) for v in batch[\"label\"]]\n",
        "        return enc\n",
        "\n",
        "    def build_loader(split, batch_size, shuffle):\n",
        "        # Shared train/dev pipeline; only split, batch size and shuffling differ.\n",
        "        ds = load_dataset(\"glue\", \"stsb\", split=split)\n",
        "        if shuffle:\n",
        "            ds = ds.shuffle(seed)\n",
        "        ds = ds.map(preprocess, batched=True)\n",
        "        ds = ds.remove_columns([\"sentence1\", \"sentence2\", \"label\", \"idx\"])\n",
        "        ds.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
        "        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, collate_fn=collator)\n",
        "\n",
        "    train_loader = build_loader(\"train\", batch_size=8, shuffle=True)\n",
        "    dev_loader = build_loader(\"validation\", batch_size=16, shuffle=False)\n",
        "\n",
        "    model, er_scores = full_finetuning(train_loader, dev_loader, device)\n",
        "    if run_prune:\n",
        "        model = prune_and_finetuning(model, train_loader, dev_loader, device, er_scores)\n",
        "    if run_lora:\n",
        "        # Assumed to adapt `model` in place; its return value is not relied on.\n",
        "        lora_only_finetuning(model, train_loader, dev_loader, device)\n",
        "    return model\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "id": "-ae7zj04U5u5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "U3wDNkD86y-D"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}