{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "da814f23-73e4-49c1-a1c3-820715eeeab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
    "import numpy as np\n",
    "import evaluate\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "from sklearn.metrics import accuracy_score\n",
    "from torch.nn import CrossEntropyLoss\n",
    "import copy\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8818ea85-9054-44b0-93bc-ec70434a8696",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=2)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "BATCH_SIZE = 16\n",
    "\n",
    "DEVICE = torch.device(\"cuda:2\" if torch.cuda.is_available() else \"cpu\")\n",
    "DEVICE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d4a4fad-ccb7-4921-b189-f7bdf5a09973",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    print(f\"Setting seed: {seed}\")\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    import random\n",
    "\n",
    "    # Set seeds\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "    # Ensure deterministic behavior\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "seed = 2025"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0de01b89-569f-4576-a825-b633328e1edc",
   "metadata": {},
   "source": [
    "## model & data preparation "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fc258c6b-cfb1-47a9-955a-51cd83167b4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-large-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "BertForSequenceClassification(\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(28996, 1024, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 1024)\n",
       "      (token_type_embeddings): Embedding(2, 1024)\n",
       "      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-23): 24 x BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSdpaSelfAttention(\n",
       "              (query): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "              (key): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "              (value): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "              (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
       "            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       "  (classifier): Linear(in_features=1024, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pretrained model\n",
    "checkpoint = \"google-bert/bert-large-cased\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint,num_labels = 2)\n",
    "\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "351b3e9a-d3bb-48bd-89ff-476ec62d8c07",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|████████████████████████| 3668/3668 [00:00<00:00, 13845.25 examples/s]\n",
      "Map: 100%|██████████████████████████| 408/408 [00:00<00:00, 10894.43 examples/s]\n",
      "Map: 100%|████████████████████████| 1725/1725 [00:00<00:00, 14335.31 examples/s]\n"
     ]
    }
   ],
   "source": [
    "###### loading data\n",
    "raw_dataset = load_dataset(\"glue\", \"mrpc\")\n",
    "raw_dataset\n",
    "\n",
    "# Create a tokenized dataset\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"sentence1\"], examples[\"sentence2\"],\n",
    "                     padding=\"max_length\", truncation=True, max_length=128)\n",
    "\n",
    "tokenized_datasets = raw_dataset.map(tokenize_function, batched=True)\n",
    "\n",
    "tokenized_datasets = tokenized_datasets.remove_columns([\"idx\",\"sentence1\", \"sentence2\"])\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
    "tokenized_datasets = tokenized_datasets.with_format(\"torch\")\n",
    "# tokenized_datasets\n",
    "\n",
    "# small_train_dataset = tokenized_datasets[\"train\"].select(range(3600))\n",
    "# small_eval_dataset = tokenized_datasets[\"validation\"].select(range(400))\n",
    "# small_test_dataset = tokenized_datasets[\"test\"].select(range(1700))\n",
    "\n",
    "num_training_data = 1280\n",
    "train_loader = DataLoader(dataset=tokenized_datasets[\"train\"].select(range(num_training_data)),\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          shuffle=True)\n",
    "\n",
    "\n",
    "# Create DataLoader for the test dataset\n",
    "test_loader = DataLoader(dataset=tokenized_datasets[\"test\"].select(range(num_training_data//4)),\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          shuffle=True)"
   ]
  },
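  {
   "cell_type": "markdown",
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e01",
   "metadata": {},
   "source": [
    "As a quick sanity check, the cell below peeks at one batch from `train_loader` to confirm the tensor shapes the trainers assume: token tensors of shape `(BATCH_SIZE, 128)` plus a `labels` vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: every tensor in a batch should have BATCH_SIZE rows\n",
    "batch = next(iter(train_loader))\n",
    "{k: tuple(v.shape) for k, v in batch.items()}"
   ]
  },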
  {
   "cell_type": "markdown",
   "id": "3baa013c-3d47-4117-a043-818edb8ba1b8",
   "metadata": {},
   "source": [
    "## Self-defined trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7dd006e0-ecde-45da-9a4d-22929c145d54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to calculate accuracy\n",
    "def compute_accuracy(model, data_loader, device):\n",
    "    model.eval()  # Set the model to evaluation mode\n",
    "    true_labels = []\n",
    "    predictions = []\n",
    "\n",
    "    with torch.no_grad():  # Disable gradient calculation\n",
    "        for batch in data_loader:\n",
    "            # Move batch to the correct device\n",
    "            batch = {k: v.to(device) for k, v in batch.items()}\n",
    "            \n",
    "            # Get model predictions\n",
    "            outputs = model(**batch)\n",
    "            logits = outputs.logits\n",
    "            \n",
    "            # Get predicted class (highest logit value)\n",
    "            preds = torch.argmax(logits, dim=-1)\n",
    "            \n",
    "            # Store true labels and predictions\n",
    "            true_labels.extend(batch[\"labels\"].cpu().numpy())\n",
    "            predictions.extend(preds.cpu().numpy())\n",
    "    \n",
    "    # Compute accuracy\n",
    "    accuracy = accuracy_score(true_labels, predictions)\n",
    "    return accuracy\n",
    "    \n",
    "def compute_loss(model, data_loader, device):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        loss_total = 0\n",
    "        for batch in data_loader:\n",
    "            # Move batch to the correct device\n",
    "            batch = {k: v.to(device) for k, v in batch.items()}\n",
    "            \n",
    "            # Get model predictions\n",
    "            outputs = model(**batch)\n",
    "\n",
    "            loss_total += outputs.loss.item()\n",
    "        return loss_total/len(data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d8e44157-030f-4e89-b4ea-82c4290f40bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = {\"gamma\":15,\n",
    "       \"lamba\": 1e-3,\n",
    "       \"inner_lr\":5e-4,\n",
    "       \"learning_rate\":5e-4,\n",
    "       \"tokenized_train_dataset\": tokenized_datasets[\"train\"],\n",
    "       \"tokenized_test_dataset\": tokenized_datasets[\"test\"],\n",
    "       \"prop_train\": 0.66,\n",
    "       \"max_time\": 1800,\n",
    "       \"num_inner_step\":1\n",
    "       }\n",
    "\n",
    "def train(num_epochs, model, optimizer, train_loader, test_loader, args = args, device=DEVICE ):\n",
    "    max_time = args[\"max_time\"]\n",
    "    # record\n",
    "    tr_acc=[compute_accuracy(model, train_loader, device)]\n",
    "    tr_loss=[compute_loss(model, train_loader, device)]\n",
    "    \n",
    "    test_acc=[compute_accuracy(model, test_loader, device)]\n",
    "    test_loss=[compute_loss(model, test_loader, device)]\n",
    "    \n",
    "    time_stamp=[0]\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        start_time = time.time()\n",
    "        if time_stamp[-1]>=max_time:\n",
    "            break\n",
    "        model.train()\n",
    "        print(f\"Epoch {epoch+1}/{num_epochs}\")\n",
    "        for batch in train_loader:\n",
    "            # Move data to the correct device\n",
    "            batch = {k: v.to(device) for k, v in batch.items()}\n",
    "            \n",
    "            # Forward pass\n",
    "            outputs = model(**batch)\n",
    "            loss = outputs.loss # automatically cross entropy\n",
    "    \n",
    "            # Backward pass and optimization\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            loss.backward()\n",
    "\n",
    "            # UPDATE MODEL PARAMETERS\n",
    "            optimizer.step()\n",
    "\n",
    "        # recording\n",
    "        train_time = (time.time() - start_time)\n",
    "        time_stamp.append(time_stamp[-1]+train_time)\n",
    "\n",
    "        tr_acc.append(compute_accuracy(model, train_loader, device))\n",
    "        tr_loss.append(compute_loss(model, train_loader, device))\n",
    "        \n",
    "        test_acc.append(compute_accuracy(model, test_loader, device))\n",
    "        test_loss.append(compute_loss(model, test_loader, device))\n",
    "        \n",
    "        print('Time elapsed: %.2f min' % (time_stamp[-1]/60))\n",
    "        print(f\"Epoch {epoch+1} Training Loss: {tr_loss[-1]}\")\n",
    "        print(f\"Epoch {epoch+1} Training Accuracy: {tr_acc[-1]}\")\n",
    "        \n",
    "    return (tr_acc,tr_loss), (test_acc,test_loss), (time_stamp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eeb6722c-034f-4755-8c0e-56a56fddc82b",
   "metadata": {},
   "source": [
    "## Trian BiDoRa PBGD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ca23e675-6587-4235-b5ff-1487d2431d95",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_regularization(B, A, device = DEVICE):\n",
    "        V = torch.matmul(B, A)  # V = B @ A\n",
    "        delta_V = torch.zeros_like(V)  # Placeholder for Delta V, update based on the iteration context\n",
    "        reg_term = torch.norm((V + delta_V).T @ (V + delta_V) - torch.eye(V.shape[1], device=device), p='fro')**2\n",
    "        # reg_term = torch.norm(V.T @ V - torch.eye(V.shape[1], device=device), p='fro')**2\n",
    "        return reg_term\n",
    "\n",
    "def compute_val_loss(model,val_loader, device=DEVICE):\n",
    "    model.train()\n",
    "    # Validation Loss Computation\n",
    "    batch = next(iter(val_loader))\n",
    "\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "            \n",
    "    # Forward pass\n",
    "    outputs = model(**batch)\n",
    "    loss = outputs.loss # automatically cross entropy\n",
    "    return loss #/batch.size(0)\n",
    "\n",
    "def freeze_layers(model, freeze_lora_magnitude=True, freeze_lora_embedding=True):\n",
    "    # Freeze lora_magnitude_vector\n",
    "    if freeze_lora_magnitude:\n",
    "        for name, param in model.named_parameters():\n",
    "            if 'lora_magnitude_vector' in name:\n",
    "                param.requires_grad = False\n",
    "            if 'lora_embedding_A' in name or 'lora_embedding_B' in name:\n",
    "                param.requires_grad = True\n",
    "    \n",
    "    # Freeze lora_embedding_A and lora_embedding_B\n",
    "    if freeze_lora_embedding:\n",
    "        for name, param in model.named_parameters():\n",
    "            if 'lora_magnitude_vector' in name:\n",
    "                param.requires_grad = True\n",
    "            if 'lora_embedding_A' in name or 'lora_embedding_B' in name:\n",
    "                param.requires_grad = False"
   ]
  },
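  {
   "cell_type": "markdown",
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e03",
   "metadata": {},
   "source": [
    "What `compute_regularization` measures: with `V = B @ A`, the penalty is the squared Frobenius distance of `V.T @ V` from the identity, so it vanishes exactly when the columns of `V` are orthonormal. A minimal sketch on random tensors (toy shapes, not model weights):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy check of the orthogonality penalty (random matrices, not model weights)\n",
    "d, r = 32, 8\n",
    "B = torch.randn(d, r, device=DEVICE)\n",
    "A = torch.randn(r, r, device=DEVICE)\n",
    "print(\"random V:      \", compute_regularization(B, A).item())\n",
    "\n",
    "# If B has orthonormal columns and A = I, then V = B A satisfies V.T @ V = I\n",
    "Q, _ = torch.linalg.qr(torch.randn(d, r, device=DEVICE))\n",
    "print(\"orthonormal V: \", compute_regularization(Q, torch.eye(r, device=DEVICE)).item())"
   ]
  },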
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0e4c4bbc-e53e-4c3f-b91c-0beb6d72f8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = {\"gamma\":15,\n",
    "       \"lamba\": 1e-3,\n",
    "       \"inner_lr\":5e-4,\n",
    "       \"learning_rate\":5e-4,\n",
    "       \"tokenized_train_dataset\": tokenized_datasets[\"train\"],\n",
    "       \"tokenized_test_dataset\": tokenized_datasets[\"test\"],\n",
    "       \"prop_train\": 0.66,\n",
    "       \"max_time\": 1800,\n",
    "       \"num_inner_step\":1\n",
    "       }\n",
    "\n",
    "def train_bidora_penalty(num_epochs, model, args = args, device=DEVICE, penalty_term=True):\n",
    "    gamma=args[\"gamma\"]\n",
    "    lamba=args[\"lamba\"]\n",
    "    inner_lr=args[\"inner_lr\"]\n",
    "    learning_rate=args[\"learning_rate\"]\n",
    "    tokenized_train_dataset=args[\"tokenized_train_dataset\"]\n",
    "    tokenized_test_dataset=args[\"tokenized_test_dataset\"]\n",
    "    prop_train=args[\"prop_train\"]\n",
    "    max_time = args[\"max_time\"]\n",
    "    num_inner_step = args[\"num_inner_step\"]\n",
    "    \n",
    "    train_size = int(prop_train * len(tokenized_train_dataset))\n",
    "    val_size = len(tokenized_train_dataset) - train_size\n",
    "    ab_train_dataset = tokenized_train_dataset.select(range(train_size))\n",
    "    m_train_dataset = tokenized_train_dataset.select(range(train_size,train_size+val_size))\n",
    "    # Create DataLoaders for training and validation\n",
    "    train_loader = DataLoader(dataset=ab_train_dataset,\n",
    "                              batch_size=BATCH_SIZE,\n",
    "                              shuffle=True)\n",
    "    \n",
    "    val_loader = DataLoader(dataset=m_train_dataset,\n",
    "                                   batch_size=BATCH_SIZE,\n",
    "                                   shuffle=False)\n",
    "\n",
    "    test_loader = DataLoader(dataset=tokenized_test_dataset,\n",
    "                                   batch_size=BATCH_SIZE,\n",
    "                                   shuffle=False)\n",
    "    # record\n",
    "    tr_acc=[compute_accuracy(model, train_loader, device)]\n",
    "    tr_loss=[compute_loss(model, train_loader, device)]\n",
    "    \n",
    "    test_acc=[compute_accuracy(model, test_loader, device)]\n",
    "    test_loss=[compute_loss(model, test_loader, device)]\n",
    "    \n",
    "    time_stamp=[0]\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        start_time = time.time()\n",
    "        # model.train()\n",
    "        print(f\"Epoch {epoch+1}/{num_epochs}\")\n",
    "\n",
    "        if time_stamp[-1]>=max_time:\n",
    "            break\n",
    "        \n",
    "        total_loss = 0\n",
    "        for batch_val in val_loader:\n",
    "            batch_val = {k: v.to(device) for k, v in batch_val.items()}\n",
    "\n",
    "            # Phase 1: Update A and B, freeze m\n",
    "            freeze_layers(model, freeze_lora_magnitude=True, freeze_lora_embedding=False)\n",
    "            \n",
    "            optimizer_ab = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=inner_lr)\n",
    "            \n",
    "            ### LL inner loop\n",
    "            if penalty_term:\n",
    "                model_LL = copy.deepcopy(model) # it has been freezed accordingly\n",
    "                optimizer_LL = torch.optim.Adam(filter(lambda p: p.requires_grad, model_LL.parameters()), lr=inner_lr)\n",
    "                \n",
    "            for i in range(num_inner_step):#(int(prop_train/(1-prop_train))): # as number of train data is 9 times more than validation data\n",
    "                model.train()\n",
    "\n",
    "                batch = next(iter(train_loader))\n",
    "                batch = {k: v.to(device) for k, v in batch.items()}\n",
    "\n",
    "                # Forward and backward pass\n",
    "                outputs = model(**batch)\n",
    "                loss_train = outputs.loss\n",
    "                \n",
    "                # loss_val = compute_val_loss(model)\n",
    "                outputs_val = model(**batch_val)\n",
    "                loss_val = outputs_val.loss\n",
    "\n",
    "                loss_reg = 0\n",
    "                for name, param in model.named_parameters():\n",
    "                    for name2, param2 in model.named_parameters():\n",
    "                        if \"lora_embedding_A\" in name and \"lora_embedding_B\" in name2 and name[7]==name2[7]: # 7 is the layer number\n",
    "                            loss_reg += compute_regularization(param, param2)\n",
    "                # print(loss_train,loss_val,loss_reg)\n",
    "                loss = 1/gamma* loss_val  + loss_train + lamba* loss_reg\n",
    "\n",
    "\n",
    "                optimizer_ab.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer_ab.step()\n",
    "\n",
    "                if penalty_term:\n",
    "                    model_LL.train()\n",
    "                    outputs = model_LL(**batch)\n",
    "                    loss_train_LL = outputs.loss\n",
    "\n",
    "                    optimizer_LL.zero_grad()\n",
    "                    loss_train_LL.backward()\n",
    "\n",
    "                    optimizer_LL.step()\n",
    "\n",
    "\n",
    "            # Phase 2:\n",
    "            # Freeze all ab, unfreeze m\n",
    "            freeze_layers(model, freeze_lora_magnitude=False, freeze_lora_embedding=True)\n",
    "\n",
    "            optimizer_m = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)\n",
    "\n",
    "            # upper loss\n",
    "            outputs_m = model(**batch_val)\n",
    "            loss_m = outputs_m.loss\n",
    "\n",
    "\n",
    "            #penalty term\n",
    "            if penalty_term:\n",
    "                freeze_layers(model_LL, freeze_lora_magnitude=False, freeze_lora_embedding=True)\n",
    "                optimizer_m = torch.optim.Adam(filter(lambda p: p.requires_grad, list(model.parameters()) + list(model_LL.parameters())),lr=1e-3)\n",
    "\n",
    "                batch = next(iter(train_loader))\n",
    "                batch = {k: v.to(device) for k, v in batch.items()}\n",
    "                outputs_LL = model_LL(**batch)\n",
    "                outputs = model(**batch)\n",
    "                \n",
    "                # Forward and backward pass\n",
    "                loss_train_LL = outputs_LL.loss\n",
    "                loss_train = outputs.loss\n",
    "\n",
    "                loss_m += gamma*(loss_train-loss_train_LL)\n",
    "\n",
    "            \n",
    "            total_loss += loss_m\n",
    "            \n",
    "            optimizer_m.zero_grad()\n",
    "            loss_m.backward()\n",
    "            optimizer_m.step()\n",
    "\n",
    "        # recording\n",
    "        train_time = (time.time() - start_time)\n",
    "        time_stamp.append(time_stamp[-1]+train_time)\n",
    "\n",
    "        tr_acc.append(compute_accuracy(model, train_loader, device))\n",
    "        tr_loss.append(compute_loss(model, train_loader, device))\n",
    "        \n",
    "        test_acc.append(compute_accuracy(model, test_loader, device))\n",
    "        test_loss.append(compute_loss(model, test_loader, device))\n",
    "        \n",
    "        print('Time elapsed: %.2f min' % (time_stamp[-1]/60))\n",
    "        print(f\"Epoch {epoch+1} Training Loss: {tr_loss[-1]}\")\n",
    "        print(f\"Epoch {epoch+1} Training Accuracy: {tr_acc[-1]}\")\n",
    "        \n",
    "    return (tr_acc,tr_loss), (test_acc,test_loss), (time_stamp)"
   ]
  },
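  {
   "cell_type": "markdown",
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e05",
   "metadata": {},
   "source": [
    "For intuition on the penalty-based update in `train_bidora_penalty`: the tracking copy `model_LL` descends the plain training loss, and the upper-level step minimizes `loss_val + gamma * (loss_train(model) - loss_train(model_LL))`, where the bracketed term penalizes lower-level suboptimality. Below is a minimal scalar sketch of the same scheme on made-up quadratics (toy `f_tr`, `f_val`; nothing here touches the model):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy PBGD: upper variable m_ (think: magnitudes), lower variable theta (think: A, B)\n",
    "gamma_, lr_ = 15.0, 0.05\n",
    "m_ = torch.tensor(0.0, requires_grad=True)\n",
    "theta = torch.tensor(0.0, requires_grad=True)\n",
    "theta_LL = torch.tensor(0.0, requires_grad=True)  # tracks the lower-level minimizer\n",
    "\n",
    "f_tr = lambda th, mm: (th - mm) ** 2  # lower-level (training) loss\n",
    "f_val = lambda th: (th - 1.0) ** 2    # upper-level (validation) loss\n",
    "\n",
    "for step in range(500):\n",
    "    # Phase 1: lower-level step on theta (same objective shape as the inner loop above)\n",
    "    g, = torch.autograd.grad(f_val(theta) / gamma_ + f_tr(theta, m_.detach()), theta)\n",
    "    theta = (theta - lr_ * g).detach().requires_grad_(True)\n",
    "\n",
    "    # The tracking copy only sees the training loss\n",
    "    g_LL, = torch.autograd.grad(f_tr(theta_LL, m_.detach()), theta_LL)\n",
    "    theta_LL = (theta_LL - lr_ * g_LL).detach().requires_grad_(True)\n",
    "\n",
    "    # Phase 2: upper-level step with the penalty gamma * (f_tr(theta) - f_tr(theta_LL))\n",
    "    loss_m = f_val(theta.detach()) + gamma_ * (f_tr(theta.detach(), m_) - f_tr(theta_LL.detach(), m_))\n",
    "    g_m, = torch.autograd.grad(loss_m, m_)\n",
    "    m_ = (m_ - lr_ * g_m).detach().requires_grad_(True)\n",
    "\n",
    "print(f\"m = {m_.item():.3f}, theta = {theta.item():.3f}  (bilevel optimum: m = theta = 1)\")"
   ]
  },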
  {
   "cell_type": "markdown",
   "id": "215ae97e-11d9-45a0-b32e-bcc3ccac6d39",
   "metadata": {},
   "source": [
    "## Train BiDoRa-Origin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "071a54df-25ac-4eda-a027-0fde0b46d66f",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = {\"gamma\":15,\n",
    "       \"lamba\": 1e-3,\n",
    "       \"inner_lr\":5e-4,\n",
    "       \"learning_rate\":5e-4,\n",
    "       \"tokenized_train_dataset\": tokenized_datasets[\"train\"],\n",
    "       \"tokenized_test_dataset\": tokenized_datasets[\"test\"],\n",
    "       \"prop_train\": 0.66,\n",
    "       \"max_time\": 1800,\n",
    "       \"num_inner_step\":1\n",
    "       }\n",
    "\n",
    "def train_bidora_approx(num_epochs, model, epsilon=1e-3, args = args, device=DEVICE,xi = 0.1):\n",
    "    # This is the algorithm that was used in the biDoRa paper\n",
    "    gamma=args[\"gamma\"]\n",
    "    lamba=args[\"lamba\"]\n",
    "    inner_lr=args[\"inner_lr\"]\n",
    "    learning_rate=args[\"learning_rate\"]\n",
    "    tokenized_train_dataset=args[\"tokenized_train_dataset\"]\n",
    "    tokenized_test_dataset=args[\"tokenized_test_dataset\"]\n",
    "    prop_train=args[\"prop_train\"]\n",
    "    max_time = args[\"max_time\"]\n",
    "    num_inner_step = args[\"num_inner_step\"]\n",
    "    \n",
    "    train_size = int(prop_train * len(tokenized_train_dataset))\n",
    "    val_size = len(tokenized_train_dataset) - train_size\n",
    "    ab_train_dataset = tokenized_train_dataset.select(range(train_size))\n",
    "    m_train_dataset = tokenized_train_dataset.select(range(train_size,train_size+val_size))\n",
    "    # Create DataLoaders for training and validation\n",
    "    train_loader = DataLoader(dataset=ab_train_dataset,\n",
    "                              batch_size=BATCH_SIZE,\n",
    "                              shuffle=True)\n",
    "    \n",
    "    val_loader = DataLoader(dataset=m_train_dataset,\n",
    "                                   batch_size=BATCH_SIZE,\n",
    "                                   shuffle=False)\n",
    "\n",
    "    test_loader = DataLoader(dataset=tokenized_test_dataset,\n",
    "                                   batch_size=BATCH_SIZE,\n",
    "                                   shuffle=False)\n",
    "    # record\n",
    "    tr_acc=[compute_accuracy(model, train_loader, device)]\n",
    "    tr_loss=[compute_loss(model, train_loader, device)]\n",
    "    \n",
    "    test_acc=[compute_accuracy(model, test_loader, device)]\n",
    "    test_loss=[compute_loss(model, test_loader, device)]\n",
    "    \n",
    "    time_stamp=[0]\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        start_time = time.time()\n",
    "        # model.train()\n",
    "        print(f\"Epoch {epoch+1}/{num_epochs}\")\n",
    "        if time_stamp[-1]>=max_time:\n",
    "            break\n",
    "        \n",
    "        total_loss = 0\n",
    "        for batch_val in val_loader:\n",
    "            batch_val = {k: v.to(device) for k, v in batch_val.items()}\n",
    "\n",
    "            # Phase 1: Update A and B, freeze m\n",
    "            freeze_layers(model, freeze_lora_magnitude=True, freeze_lora_embedding=False)\n",
    "            \n",
    "            optimizer_ab = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=inner_lr)\n",
    "            \n",
    "            ### LL inner loop\n",
    "            for i in range(num_inner_step):#(int(prop_train/(1-prop_train))): # as number of train data is 9 times more than validation data\n",
    "                model.train()\n",
    "\n",
    "                batch = next(iter(train_loader))\n",
    "                batch = {k: v.to(device) for k, v in batch.items()}\n",
    "\n",
    "                # Forward and backward pass\n",
    "                outputs = model(**batch)\n",
    "                loss_train = outputs.loss\n",
    "                \n",
    "                # loss_val = compute_val_loss(model)\n",
    "                outputs_val = model(**batch_val)\n",
    "                loss_val = outputs_val.loss\n",
    "\n",
    "                loss_reg = 0\n",
    "                for name, param in model.named_parameters():\n",
    "                    for name2, param2 in model.named_parameters():\n",
    "                        if \"lora_embedding_A\" in name and \"lora_embedding_B\" in name2 and name[7]==name2[7]: # 7 is the layer number\n",
    "                            loss_reg += compute_regularization(param, param2)\n",
    "                # print(loss_train,loss_val,loss_reg)\n",
    "                loss = 1/gamma* loss_val  + loss_train + lamba* loss_reg\n",
    "\n",
    "\n",
    "                optimizer_ab.zero_grad()\n",
    "                loss.backward()\n",
    "                optimizer_ab.step()\n",
    "\n",
    "            # Phase 2:\n",
    "            # Freeze all ab, unfreeze m\n",
    "            freeze_layers(model, freeze_lora_magnitude=False, freeze_lora_embedding=True)\n",
    "\n",
    "            optimizer_m = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)\n",
    "\n",
    "            # upper loss\n",
    "            outputs_m = model(**batch_val)\n",
    "            loss_m = outputs_m.loss\n",
    "\n",
    "            grad_v = torch.autograd.grad(loss_m, filter(lambda p: p.requires_grad, model.parameters()), create_graph=True)\n",
    "\n",
    "            # keep the m params at this status\n",
    "            original_model_params = [param.clone() for param in filter(lambda p: p.requires_grad, model.parameters())]\n",
    "            \n",
    "            # Compute perturbed M^+ and M^-\n",
    "            M_plus = [param + epsilon * grad for param, grad in zip(filter(lambda p: p.requires_grad, model.parameters()), grad_v)]\n",
    "            M_minus = [param - epsilon * grad for param, grad in zip(filter(lambda p: p.requires_grad, model.parameters()), grad_v)]\n",
    "            \n",
    "            ### Compute L_tr(M^+) and L_tr(M^-)\n",
    "            # Replace model parameters with M^+ and compute training loss\n",
    "            with torch.no_grad():\n",
    "                for param, new_param in zip(filter(lambda p: p.requires_grad, model.parameters()), M_plus):\n",
    "                    param.data.copy_(new_param.data)\n",
    "            outputs_plus = model(**batch)\n",
    "            train_loss_plus = outputs_plus.loss\n",
    "        \n",
    "            # Replace model parameters with M^- and compute training loss\n",
    "            with torch.no_grad():\n",
    "                for param, new_param in zip(filter(lambda p: p.requires_grad, model.parameters()), M_minus):\n",
    "                    param.data.copy_(new_param.data)\n",
    "            outputs_minus = model(**batch)\n",
    "            train_loss_minus = outputs_minus.loss\n",
    "            \n",
    "            # Reset model parameters to original M (unperturbed state)\n",
    "            with torch.no_grad():\n",
    "                for param, original_param in zip(filter(lambda p: p.requires_grad, model.parameters()), original_model_params):\n",
    "                    param.data.copy_(original_param.data)\n",
    "        \n",
    "            # Compute gradients of L_tr(M^+) and L_tr(M^-)\n",
    "            grad_tr_plus = torch.autograd.grad(train_loss_plus, filter(lambda p: p.requires_grad, model.parameters()), create_graph=True)\n",
    "            grad_tr_minus = torch.autograd.grad(train_loss_minus, filter(lambda p: p.requires_grad, model.parameters()), create_graph=True)\n",
    "            \n",
    "            # Compute second-order approximation\n",
    "            grad_approx = []\n",
    "            for grad_plus, grad_minus in zip(grad_tr_plus, grad_tr_minus):\n",
    "                grad_approx.append((grad_plus - grad_minus) / (2 * epsilon))  # Central difference\n",
    "            \n",
    "            # Compute the final gradient approximation\n",
    "            final_grad = []\n",
    "            for grad_v, grad_m in zip(grad_v, grad_approx):\n",
    "                final_grad.append(grad_v - xi * grad_m)  # Final gradient\n",
    "            \n",
    "            # Zero the gradients before backward pass\n",
    "            optimizer_m.zero_grad()\n",
    "            \n",
    "            # Manually accumulate gradients for each parameter\n",
    "            for param, grad in zip(filter(lambda p: p.requires_grad, model.parameters()), final_grad):\n",
    "                param.grad = grad  # Set the manually computed gradient for each parameter\n",
    "            \n",
    "            # Perform the optimizer step to update the parameters using Adam\n",
    "            optimizer_m.step()\n",
    "\n",
    "        # recording\n",
    "        train_time = (time.time() - start_time)\n",
    "        time_stamp.append(time_stamp[-1]+train_time)\n",
    "\n",
    "        tr_acc.append(compute_accuracy(model, train_loader, device))\n",
    "        tr_loss.append(compute_loss(model, train_loader, device))\n",
    "        \n",
    "        test_acc.append(compute_accuracy(model, test_loader, device))\n",
    "        test_loss.append(compute_loss(model, test_loader, device))\n",
    "        \n",
    "        print('Time elapsed: %.2f min' % (time_stamp[-1]/60))\n",
    "        print(f\"Epoch {epoch+1} Training Loss: {tr_loss[-1]}\")\n",
    "        print(f\"Epoch {epoch+1} Training Accuracy: {tr_acc[-1]}\")\n",
    "        \n",
    "    return (tr_acc,tr_loss), (test_acc,test_loss), (time_stamp)"
   ]
  },
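  {
   "cell_type": "markdown",
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e07",
   "metadata": {},
   "source": [
    "The `M^+`/`M^-` perturbation in `train_bidora_approx` is a finite-difference stand-in for a Hessian-vector product: `(grad L_tr(M + eps*v) - grad L_tr(M - eps*v)) / (2*eps)` approximates `H v` with `v = grad_v`. A self-contained check of that stencil on a toy quadratic with a known Hessian (small made-up shapes, unrelated to the model):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Central-difference Hessian-vector product check on f(x) = 0.5 * x^T H x\n",
    "torch.manual_seed(0)\n",
    "n, eps_ = 5, 1e-3\n",
    "H = torch.randn(n, n, dtype=torch.float64)\n",
    "H = H @ H.T  # known symmetric Hessian, so grad f(x) = H x\n",
    "f = lambda x: 0.5 * x @ H @ x\n",
    "\n",
    "def grad_f(x):\n",
    "    x = x.clone().requires_grad_(True)\n",
    "    return torch.autograd.grad(f(x), x)[0]\n",
    "\n",
    "x = torch.randn(n, dtype=torch.float64)\n",
    "v = torch.randn(n, dtype=torch.float64)  # plays the role of grad_v above\n",
    "hvp_fd = (grad_f(x + eps_ * v) - grad_f(x - eps_ * v)) / (2 * eps_)\n",
    "print(torch.allclose(hvp_fd, H @ v, atol=1e-6))  # True: the stencil recovers H v"
   ]
  },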
  {
   "cell_type": "markdown",
   "id": "ce656149-84a9-4277-b8ef-0649a3dbb629",
   "metadata": {},
   "source": [
    "## Train DoRa, BiDoRa"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d72715e-dfb9-4848-bef1-aac02766fd10",
   "metadata": {},
   "source": [
    "### DoRa"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5bb11243-a79f-4c33-b629-9ed89d5b7730",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-large-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting seed: 0\n",
      "random_seed: 0\n",
      "========Iteration: 0 =========\n",
      "Epoch 1/15\n",
      "Time elapsed: 0.66 min\n",
      "Epoch 1 Training Loss: 0.5340764075517654\n",
      "Epoch 1 Training Accuracy: 0.75859375\n",
      "Epoch 2/15\n",
      "Time elapsed: 1.32 min\n",
      "Epoch 2 Training Loss: 0.36950838342309\n",
      "Epoch 2 Training Accuracy: 0.8546875\n",
      "Epoch 3/15\n",
      "Time elapsed: 1.97 min\n",
      "Epoch 3 Training Loss: 0.21372515708208084\n",
      "Epoch 3 Training Accuracy: 0.93984375\n",
      "Epoch 4/15\n",
      "Time elapsed: 2.62 min\n",
      "Epoch 4 Training Loss: 0.10377543163485825\n",
      "Epoch 4 Training Accuracy: 0.97265625\n",
      "Epoch 5/15\n",
      "Time elapsed: 3.27 min\n",
      "Epoch 5 Training Loss: 0.0918227544403635\n",
      "Epoch 5 Training Accuracy: 0.96796875\n",
      "Epoch 6/15\n",
      "Time elapsed: 3.92 min\n",
      "Epoch 6 Training Loss: 0.03835576356505044\n",
      "Epoch 6 Training Accuracy: 0.9859375\n",
      "Epoch 7/15\n",
      "Time elapsed: 4.56 min\n",
      "Epoch 7 Training Loss: 0.013878559251315892\n",
      "Epoch 7 Training Accuracy: 0.996875\n",
      "Epoch 8/15\n",
      "Time elapsed: 5.21 min\n",
      "Epoch 8 Training Loss: 0.0037915125547442587\n",
      "Epoch 8 Training Accuracy: 0.9984375\n",
      "Epoch 9/15\n",
      "Time elapsed: 5.86 min\n",
      "Epoch 9 Training Loss: 0.010763423240496195\n",
      "Epoch 9 Training Accuracy: 0.996875\n",
      "Epoch 10/15\n",
      "Time elapsed: 6.51 min\n",
      "Epoch 10 Training Loss: 0.0077667911944445224\n",
      "Epoch 10 Training Accuracy: 0.99921875\n",
      "Epoch 11/15\n",
      "Time elapsed: 6.82 min\n",
      "Epoch 11 Training Loss: 0.015710535048492603\n",
      "Epoch 11 Training Accuracy: 0.99609375\n",
      "Epoch 12/15\n",
      "Time elapsed: 7.13 min\n",
      "Epoch 12 Training Loss: 0.01924759293688112\n",
      "Epoch 12 Training Accuracy: 0.99453125\n",
      "Epoch 13/15\n",
      "Time elapsed: 7.79 min\n",
      "Epoch 13 Training Loss: 0.008833059579046676\n",
      "Epoch 13 Training Accuracy: 0.99765625\n",
      "Epoch 14/15\n",
      "Time elapsed: 8.43 min\n",
      "Epoch 14 Training Loss: 0.004080896378218313\n",
      "Epoch 14 Training Accuracy: 0.99921875\n",
      "Epoch 15/15\n",
      "Time elapsed: 9.09 min\n",
      "Epoch 15 Training Loss: 0.005632406488439301\n",
      "Epoch 15 Training Accuracy: 0.99921875\n",
      "Test accuracy DoRA finetune: 0.80\n",
      "Setting seed: 1\n",
      "random_seed: 1\n",
      "========Iteration: 1 =========\n",
      "Epoch 1/15\n",
      "Time elapsed: 0.65 min\n",
      "Epoch 1 Training Loss: 0.5150769166648388\n",
      "Epoch 1 Training Accuracy: 0.73515625\n",
      "Epoch 2/15\n",
      "Time elapsed: 1.31 min\n",
      "Epoch 2 Training Loss: 0.3770384628325701\n",
      "Epoch 2 Training Accuracy: 0.85234375\n",
      "Epoch 3/15\n",
      "Time elapsed: 1.96 min\n",
      "Epoch 3 Training Loss: 0.3327138228341937\n",
      "Epoch 3 Training Accuracy: 0.8484375\n",
      "Epoch 4/15\n",
      "Time elapsed: 2.61 min\n",
      "Epoch 4 Training Loss: 0.15560125963529572\n",
      "Epoch 4 Training Accuracy: 0.94296875\n",
      "Epoch 5/15\n",
      "Time elapsed: 3.27 min\n",
      "Epoch 5 Training Loss: 0.05615328258136287\n",
      "Epoch 5 Training Accuracy: 0.984375\n",
      "Epoch 6/15\n",
      "Time elapsed: 3.92 min\n",
      "Epoch 6 Training Loss: 0.02520382119982969\n",
      "Epoch 6 Training Accuracy: 0.99296875\n",
      "Epoch 7/15\n",
      "Time elapsed: 4.57 min\n",
      "Epoch 7 Training Loss: 0.014647382686962373\n",
      "Epoch 7 Training Accuracy: 0.996875\n",
      "Epoch 8/15\n",
      "Time elapsed: 5.23 min\n",
      "Epoch 8 Training Loss: 0.007443890433205524\n",
      "Epoch 8 Training Accuracy: 0.99921875\n",
      "Epoch 9/15\n",
      "Time elapsed: 5.88 min\n",
      "Epoch 9 Training Loss: 0.007074823108996498\n",
      "Epoch 9 Training Accuracy: 0.9984375\n",
      "Epoch 10/15\n",
      "Time elapsed: 6.53 min\n",
      "Epoch 10 Training Loss: 0.007198427404364338\n",
      "Epoch 10 Training Accuracy: 0.99921875\n",
      "Epoch 11/15\n",
      "Time elapsed: 6.84 min\n",
      "Epoch 11 Training Loss: 0.015655236573365983\n",
      "Epoch 11 Training Accuracy: 0.99375\n",
      "Epoch 12/15\n",
      "Time elapsed: 7.15 min\n",
      "Epoch 12 Training Loss: 0.005323883180244593\n",
      "Epoch 12 Training Accuracy: 0.99921875\n",
      "Epoch 13/15\n",
      "Time elapsed: 7.46 min\n",
      "Epoch 13 Training Loss: 0.009399927510821726\n",
      "Epoch 13 Training Accuracy: 0.9984375\n",
      "Epoch 14/15\n",
      "Time elapsed: 8.11 min\n",
      "Epoch 14 Training Loss: 0.0024222789514169564\n",
      "Epoch 14 Training Accuracy: 0.99921875\n",
      "Epoch 15/15\n",
      "Time elapsed: 8.76 min\n",
      "Epoch 15 Training Loss: 0.0012617459445209533\n",
      "Epoch 15 Training Accuracy: 0.99921875\n",
      "Test accuracy DoRA finetune: 0.80\n",
      "Setting seed: 2\n",
      "random_seed: 2\n",
      "========Iteration: 2 =========\n",
      "Epoch 1/15\n",
      "Time elapsed: 0.65 min\n",
      "Epoch 1 Training Loss: 0.7643219739198684\n",
      "Epoch 1 Training Accuracy: 0.33203125\n",
      "Epoch 2/15\n",
      "Time elapsed: 1.31 min\n",
      "Epoch 2 Training Loss: 0.6448049794882535\n",
      "Epoch 2 Training Accuracy: 0.66796875\n",
      "Epoch 3/15\n",
      "Time elapsed: 1.96 min\n",
      "Epoch 3 Training Loss: 0.6468218445777894\n",
      "Epoch 3 Training Accuracy: 0.66796875\n",
      "Epoch 4/15\n",
      "Time elapsed: 2.61 min\n",
      "Epoch 4 Training Loss: 0.6467950619757176\n",
      "Epoch 4 Training Accuracy: 0.66796875\n",
      "Epoch 5/15\n",
      "Time elapsed: 3.26 min\n",
      "Epoch 5 Training Loss: 0.6699320007115602\n",
      "Epoch 5 Training Accuracy: 0.66796875\n",
      "Epoch 6/15\n",
      "Time elapsed: 3.92 min\n",
      "Epoch 6 Training Loss: 0.6363665070384741\n",
      "Epoch 6 Training Accuracy: 0.66796875\n",
      "Epoch 7/15\n",
      "Time elapsed: 4.57 min\n",
      "Epoch 7 Training Loss: 0.6463688340038061\n",
      "Epoch 7 Training Accuracy: 0.66796875\n",
      "Epoch 8/15\n",
      "Time elapsed: 5.22 min\n",
      "Epoch 8 Training Loss: 0.6389395613223314\n",
      "Epoch 8 Training Accuracy: 0.66796875\n",
      "Epoch 9/15\n",
      "Time elapsed: 5.88 min\n",
      "Epoch 9 Training Loss: 0.6428885467350482\n",
      "Epoch 9 Training Accuracy: 0.66796875\n",
      "Epoch 10/15\n",
      "Time elapsed: 6.53 min\n",
      "Epoch 10 Training Loss: 0.6448517203330993\n",
      "Epoch 10 Training Accuracy: 0.66796875\n",
      "Epoch 11/15\n",
      "Time elapsed: 7.18 min\n",
      "Epoch 11 Training Loss: 0.6356121756136417\n",
      "Epoch 11 Training Accuracy: 0.66796875\n",
      "Epoch 12/15\n",
      "Time elapsed: 7.49 min\n",
      "Epoch 12 Training Loss: 0.6410062834620476\n",
      "Epoch 12 Training Accuracy: 0.66796875\n",
      "Epoch 13/15\n",
      "Time elapsed: 7.80 min\n",
      "Epoch 13 Training Loss: 0.6357512559741736\n",
      "Epoch 13 Training Accuracy: 0.66796875\n",
      "Epoch 14/15\n",
      "Time elapsed: 8.40 min\n",
      "Epoch 14 Training Loss: 0.641251940280199\n",
      "Epoch 14 Training Accuracy: 0.66796875\n",
      "Epoch 15/15\n",
      "Time elapsed: 9.05 min\n",
      "Epoch 15 Training Loss: 0.6358055237680673\n",
      "Epoch 15 Training Accuracy: 0.66796875\n",
      "Test accuracy DoRA finetune: 0.66\n"
     ]
    }
   ],
   "source": [
    "num_exp = 3\n",
    "num_epochs = 15\n",
    "learning_rate =5e-4\n",
    "\n",
    "output_dora_list = []\n",
    "\n",
    "checkpoint = \"google-bert/bert-large-cased\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint,num_labels = 2)\n",
    "model_dora = copy.deepcopy(model) # AutoModelForSequenceClassification.from_pretrained(checkpoint,num_labels = 2)\n",
    "peft_config_dora = LoraConfig(use_dora=True, task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)\n",
    "model_dora = get_peft_model(model_dora, peft_config_dora)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "args = {\"gamma\":15,\n",
    "       \"lamba\": 1e-3,\n",
    "       \"inner_lr\":5e-4,\n",
    "       \"learning_rate\":5e-4,\n",
    "       \"tokenized_train_dataset\": tokenized_datasets[\"train\"],\n",
    "       \"tokenized_test_dataset\": tokenized_datasets[\"test\"],\n",
    "       \"prop_train\": 0.66,\n",
    "       \"max_time\": 3600,\n",
    "       \"num_inner_step\":1\n",
    "       }\n",
    "\n",
    "for i in range(num_exp):\n",
    "    set_seed(i)\n",
    "    print(\"random_seed:\",i)\n",
    "    print(\"========Iteration:\",i,\"=========\")\n",
    "    model_dora_=copy.deepcopy(model_dora)\n",
    "    model_dora_.to(DEVICE)\n",
    "    optimizer = torch.optim.Adam(model_dora_.parameters(), lr=learning_rate)\n",
    "    output_dora = train(num_epochs, model_dora_, optimizer, train_loader, test_loader, args = args)\n",
    "    output_dora_list.append(output_dora)\n",
    "    print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora_, test_loader, DEVICE):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "351c6003-9abe-46e4-a766-31f8fa5ffd26",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "with open(output_dir+\"output_dora_list.txt\", \"w\") as f:\n",
    "    f.write(str(output_dora_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b0ca469e-fd9d-4406-8dab-772959674dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/model/\"\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "torch.save(model_dora_.state_dict(), output_dir+\"model_dora.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25fa024c-2798-4768-bd8b-96befdfa8ed2",
   "metadata": {},
   "source": [
    "### BiDoRa PBGD & PBGD_Free"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ff7870d-2849-4460-8e56-acd0f1bd20b5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-large-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting seed: 0\n",
      "random_seed: 0\n",
      "========Iteration: 0 =========\n"
     ]
    }
   ],
   "source": [
    "output_bidora_PBGD_Free_list = []\n",
    "# num_exp = 10\n",
    "# num_epochs = 10\n",
    "# learning_rate =5e-4\n",
    "\n",
    "checkpoint = \"google-bert/bert-large-cased\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint,num_labels = 2)\n",
    "model_dora = copy.deepcopy(model) # AutoModelForSequenceClassification.from_pretrained(checkpoint,num_labels = 2)\n",
    "peft_config_dora = LoraConfig(use_dora=True, task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)\n",
    "model_dora = get_peft_model(model_dora, peft_config_dora)\n",
    "\n",
    "# args = {\"gamma\":15,\n",
    "#        \"lamba\": 5e-4,\n",
    "#        \"inner_lr\":5e-4,\n",
    "#        \"learning_rate\":5e-4,\n",
    "#        \"tokenized_train_dataset\": tokenized_datasets[\"train\"],\n",
    "#        \"tokenized_test_dataset\": tokenized_datasets[\"test\"],\n",
    "#        \"prop_train\": 0.66}\n",
    "\n",
    "for i in range(num_exp):\n",
    "    set_seed(i)\n",
    "    print(\"random_seed:\",i)\n",
    "    print(\"========Iteration:\",i,\"=========\")\n",
    "    model_dora_=copy.deepcopy(model_dora)\n",
    "    model_dora_.to(DEVICE)\n",
    "    output_dora = train_bidora_penalty(num_epochs, model_dora_,args=args,penalty_term=False)\n",
    "    output_bidora_PBGD_Free_list.append(output_dora)\n",
    "    print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora_, test_loader, DEVICE):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "563e689a-debe-43ff-8b41-d3008c28dbc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/model/\"\n",
    "torch.save(model_dora_.state_dict(), output_dir+\"model_bidora_PBGD_Free.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ef2266e-46c4-4550-8525-84d0689f041b",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/\"\n",
    "with open(output_dir+\"output_bidora_PBGD_Free_list.txt\", \"w\") as f:\n",
    "    f.write(str(output_bidora_PBGD_Free_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2dfdebb-e285-45df-ae4e-f4af802c9c61",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "output_bidora_list = []\n",
    "\n",
    "for i in range(num_exp):\n",
    "    set_seed(i)\n",
    "    print(\"random_seed:\",i)\n",
    "    print(\"========Iteration:\",i,\"=========\")\n",
    "    model_dora_=copy.deepcopy(model_dora)\n",
    "    model_dora_.to(DEVICE)\n",
    "    output_dora = train_bidora_penalty(num_epochs, model_dora_,args=args,penalty_term=True)\n",
    "    output_bidora_list.append(output_dora)\n",
    "    print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora_, test_loader, DEVICE):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9aeabc48-4657-4a0d-8441-5f88cc6753b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/model/\"\n",
    "torch.save(model_dora_.state_dict(), output_dir+\"model_bidora.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5e93017-e622-4f79-91c8-9af277b60e84",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/\"\n",
    "with open(output_dir+\"output_bidora_list.txt\", \"w\") as f:\n",
    "    f.write(str(output_bidora_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3144d1d5-5e00-4119-99cc-c422a3b6ac4f",
   "metadata": {},
   "source": [
    "### BiDoRa Origin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d63d568-2557-4ffe-8c75-ad278563039a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "checkpoint = \"google-bert/bert-large-cased\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint,num_labels = 2)\n",
    "model_dora = copy.deepcopy(model) # AutoModelForSequenceClassification.from_pretrained(checkpoint,num_labels = 2)\n",
    "peft_config_dora = LoraConfig(use_dora=True, task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)\n",
    "model_dora = get_peft_model(model_dora, peft_config_dora)\n",
    "# args = {\"gamma\":15,\n",
    "#        \"lamba\": 1e-3,\n",
    "#        \"inner_lr\":5e-4,\n",
    "#        \"learning_rate\":5e-4,\n",
    "#        \"tokenized_train_dataset\": tokenized_datasets[\"train\"].select(range(320)),\n",
    "#        \"tokenized_test_dataset\": tokenized_datasets[\"test\"].select(range(160)),\n",
    "#        \"prop_train\": 0.66}\n",
    "\n",
    "\n",
    "# output_bidora_appx_list = []\n",
    "# num_exp = 3 \n",
    "# num_epochs =10\n",
    "for i in range(num_exp):\n",
    "    set_seed(i)\n",
    "    print(\"random_seed:\",i)\n",
    "    print(\"========Iteration:\",i,\"=========\")\n",
    "    model_dora_=copy.deepcopy(model_dora)\n",
    "    model_dora_.to(DEVICE)\n",
    "    output_dora = train_bidora_approx(num_epochs, model_dora_,args= args)\n",
    "    output_bidora_appx_list.append(output_dora)\n",
    "    print(f'Test accuracy DoRA finetune: {compute_accuracy(model_dora_, test_loader, DEVICE):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c20ee63a-233c-4da8-92a4-91489a68f932",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/model/\"\n",
    "torch.save(model_dora_.state_dict(), output_dir+\"model_bidora_origin.pth\")\n",
    "output_dir = \"output/\"\n",
    "with open(output_dir+\"output_bidora_origin_list.txt\", \"w\") as f:\n",
    "    f.write(str(output_bidora_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcebdeb9-965c-4075-aee1-d828ede4da79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
    "output = train(num_epochs, model, optimizer, train_loader, DEVICE)\n",
    "print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "282e632a-155c-458c-84ec-0a839e21489b",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/\"\n",
    "with open(output_dir+\"output.txt\", \"w\") as f:\n",
    "    f.write(str(output))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91c98f15-8914-4d30-aaa5-21a0302c058f",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/\"\n",
    "with open(output_dir+\"output_lora.txt\", \"w\") as f:\n",
    "    f.write(str(output_lora))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b026025-2708-41fc-afd8-0473b5a410e1",
   "metadata": {},
   "source": [
    "# Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c669e29e-0e1f-4aef-a8c7-0128d7437c92",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_dir = \"output/\"\n",
    "with open(output_dir+\"output_dora_list.txt\", \"r\") as f:\n",
    "    output_dora_list = eval(f.read())\n",
    "    \n",
    "with open(output_dir+\"output_bidora_PBGD_Free_list.txt\", \"r\") as f:\n",
    "    output_bidora_PBGD_Free_list = eval(f.read())\n",
    "\n",
    "with open(output_dir+\"output_bidora_list.txt\", \"r\") as f:\n",
    "    output_bidora_list = eval(f.read())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d60ca193-6be3-430a-b448-21541b6f25fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_and_pad(output_lists):\n",
    "    acc_lists = []\n",
    "    max_len = max(len(lst) for lst in output_lists)\n",
    "\n",
    "    for lst in output_lists:\n",
    "        if len(lst) == max_len:\n",
    "            acc_lists.append(lst)\n",
    "        else:\n",
    "            lst += [lst[-1]] * (max_len - len(lst))  # pad with last value\n",
    "            acc_lists.append(lst)\n",
    "\n",
    "    return np.array(acc_lists)"
   ]
  },
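  {
   "cell_type": "markdown",
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e09",
   "metadata": {},
   "source": [
    "Because `max_time` can stop runs after different numbers of epochs, the recorded curves may have unequal lengths; `extract_and_pad` right-pads each run with its final value (note that it mutates the input lists in place). A quick toy example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy example: the shorter run is padded with its last value up to length 3\n",
    "demo = [[0.5, 0.6, 0.7], [0.5, 0.8]]\n",
    "extract_and_pad(demo)"
   ]
  },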
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95d7c898-fc87-419b-9bed-6e3cc55342f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_acc_dora = []\n",
    "for i in range(len(output_dora_list)):\n",
    "    test_acc_dora.append(output_dora_list[i][1][0])\n",
    "test_acc_dora = extract_and_pad(test_acc_dora)\n",
    "\n",
    "test_acc_bidora = []\n",
    "for i in range(len(output_bidora_list)):\n",
    "    test_acc_bidora.append(output_bidora_list[i][1][0])\n",
    "test_acc_bidora = extract_and_pad(test_acc_bidora)\n",
    "\n",
    "test_acc_bidora_PBGD_Free = []\n",
    "for i in range(len(output_bidora_PBGD_Free_list)):\n",
    "    test_acc_bidora_PBGD_Free.append(output_bidora_PBGD_Free_list[i][1][0])\n",
    "test_acc_bidora_PBGD_Free = extract_and_pad(test_acc_bidora_PBGD_Free)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7acc89aa-a72f-4492-ad38-8d4facb75696",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(test_acc_bidora_PBGD_Free.mean(0)[-1],test_acc_bidora_PBGD_Free.std(0)[-1])\n",
    "print(test_acc_bidora.mean(0)[-1],test_acc_bidora.std(0)[-1])\n",
    "print(test_acc_dora.mean(0)[-1],test_acc_dora.std(0)[-1])"
   ]
  },
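  {
   "cell_type": "markdown",
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e11",
   "metadata": {},
   "source": [
    "Beyond the final-epoch numbers above, a minimal plotting sketch (assumes `matplotlib` is installed; the x-axis is the epoch index, comparable across methods because the runs were padded to a common length):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0b1c2d3-e4f5-4a6b-8c7d-1a2b3c4d5e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Mean test accuracy per epoch with a +/- 1 std band for each method\n",
    "for name, acc in [(\"DoRA\", test_acc_dora),\n",
    "                  (\"BiDoRA (PBGD)\", test_acc_bidora),\n",
    "                  (\"BiDoRA (PBGD-Free)\", test_acc_bidora_PBGD_Free)]:\n",
    "    mean, std = acc.mean(0), acc.std(0)\n",
    "    epochs = np.arange(len(mean))\n",
    "    plt.plot(epochs, mean, label=name)\n",
    "    plt.fill_between(epochs, mean - std, mean + std, alpha=0.2)\n",
    "\n",
    "plt.xlabel(\"epoch\")\n",
    "plt.ylabel(\"test accuracy\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },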
  {
   "cell_type": "markdown",
   "id": "6fdc9822-5c2c-415a-933f-765db74b9f58",
   "metadata": {},
   "source": [
    "# Movie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62897a2c-fe3b-476c-bcca-e1f1b92c64f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_datasets = load_dataset(\"imdb\") \n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\", padding=\"max_length\", truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ade470fa-d1f8-4756-afb8-225c41913155",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_function(examples):\n",
    "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
    "\n",
    "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
    "\n",
    "tokenized_datasets = tokenized_datasets.remove_columns([\"text\"])\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
    "tokenized_datasets = tokenized_datasets.with_format(\"torch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9edf6dbd-0e01-4f85-8c86-e36c6f01971d",
   "metadata": {},
   "outputs": [],
   "source": [
    "small_train_dataset = tokenized_datasets[\"train\"].shuffle(seed=42).select(range(2000))\n",
    "small_eval_dataset = tokenized_datasets[\"test\"].shuffle(seed=42).select(range(1000))\n",
    "full_train_dataset = tokenized_datasets[\"train\"]\n",
    "full_eval_dataset = tokenized_datasets[\"test\"]\n",
    "\n",
    "train_loader = DataLoader(dataset= small_train_dataset,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          shuffle=True)\n",
    "\n",
    "\n",
    "# Create DataLoader for the test dataset\n",
    "test_loader = DataLoader(dataset= small_eval_dataset,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d1fcf77-0fd5-4397-b2e4-40b8e878f4e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrained model\n",
    "checkpoint = \"google-bert/bert-large-cased\"\n",
    "model = AutoModelForSequenceClassification.from_pretrained(checkpoint,num_labels = 2)\n",
    "\n",
    "model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
