{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a825ba6b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "===================================BUG REPORT===================================\n",
      "Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n",
      "For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link\n",
      "================================================================================\n",
      "CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so\n",
      "CUDA SETUP: Highest compute capability among GPUs detected: 7.5\n",
      "CUDA SETUP: Detected CUDA version 117\n",
      "CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from peft import (\n",
    "    get_peft_config,\n",
    "    get_peft_model,\n",
    "    get_peft_model_state_dict,\n",
    "    set_peft_model_state_dict,\n",
    "    PeftType,\n",
    "    PrefixTuningConfig,\n",
    "    PromptEncoderConfig,\n",
    ")\n",
    "\n",
    "import evaluate\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2bd7cbb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "model_name_or_path = \"roberta-large\"\n",
    "task = \"mrpc\"\n",
    "peft_type = PeftType.PREFIX_TUNING\n",
    "device = \"cuda\"\n",
    "num_epochs = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "33d9b62e",
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_config = PrefixTuningConfig(task_type=\"SEQ_CLS\", num_virtual_tokens=20)\n",
    "lr = 1e-2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "152b6177",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset glue (/home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be1eddbb9a7d4e6dae32fd026e167f96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9fa7887f9eaa03ae.arrow\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b61574844b6c499b8960fd4d78c5e549",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at /home/sourab/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-7e7eacaa5160936d.arrow\n"
     ]
    }
   ],
   "source": [
    "if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")):\n",
    "    padding_side = \"left\"\n",
    "else:\n",
    "    padding_side = \"right\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
    "if getattr(tokenizer, \"pad_token_id\") is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "datasets = load_dataset(\"glue\", task)\n",
    "metric = evaluate.load(\"glue\", task)\n",
    "\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    # max_length=None => use the model max length (it's actually the default)\n",
    "    outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, max_length=None)\n",
    "    return outputs\n",
    "\n",
    "\n",
    "tokenized_datasets = datasets.map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
    ")\n",
    "\n",
    "# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
    "# transformers library\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")\n",
    "\n",
    "\n",
    "def collate_fn(examples):\n",
    "    return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
    "\n",
    "\n",
    "# Instantiate dataloaders.\n",
    "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size)\n",
    "eval_dataloader = DataLoader(\n",
    "    tokenized_datasets[\"validation\"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6bc8144",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.print_trainable_parameters()\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "af41c571",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = AdamW(params=model.parameters(), lr=lr)\n",
    "\n",
    "# Instantiate scheduler\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0.06 * (len(train_dataloader) * num_epochs),\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "90993c93",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                  | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:29<00:00,  3.87it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.32it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0: {'accuracy': 0.7132352941176471, 'f1': 0.7876588021778584}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.42it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.36it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1: {'accuracy': 0.6838235294117647, 'f1': 0.8122270742358079}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.41it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.35it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2: {'accuracy': 0.8088235294117647, 'f1': 0.8717105263157895}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.39it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3: {'accuracy': 0.7549019607843137, 'f1': 0.8475609756097561}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:26<00:00,  4.37it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  8.34it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4: {'accuracy': 0.8480392156862745, 'f1': 0.8938356164383561}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [00:40<00:00,  2.87it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00,  1.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5: {'accuracy': 0.8651960784313726, 'f1': 0.9053356282271946}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:53<00:00,  1.01it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:07<00:00,  1.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6: {'accuracy': 0.8700980392156863, 'f1': 0.9065255731922399}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:42<00:00,  1.12it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7: {'accuracy': 0.8676470588235294, 'f1': 0.9042553191489361}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:27<00:00,  1.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.45it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8: {'accuracy': 0.875, 'f1': 0.9103690685413005}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:29<00:00,  1.29it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.48it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9: {'accuracy': 0.8799019607843137, 'f1': 0.913884007029877}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:43<00:00,  1.11it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00,  1.88it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10: {'accuracy': 0.8725490196078431, 'f1': 0.902621722846442}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:53<00:00,  1.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00,  2.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11: {'accuracy': 0.875, 'f1': 0.9090909090909091}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:29<00:00,  1.28it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  2.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12: {'accuracy': 0.8823529411764706, 'f1': 0.9139784946236559}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:27<00:00,  1.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13: {'accuracy': 0.8602941176470589, 'f1': 0.9018932874354562}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:27<00:00,  1.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.47it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14: {'accuracy': 0.8700980392156863, 'f1': 0.9075043630017452}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:27<00:00,  1.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15: {'accuracy': 0.875, 'f1': 0.9087656529516995}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:27<00:00,  1.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16: {'accuracy': 0.8578431372549019, 'f1': 0.9003436426116839}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:27<00:00,  1.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17: {'accuracy': 0.8627450980392157, 'f1': 0.903448275862069}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:28<00:00,  1.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  2.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18: {'accuracy': 0.8700980392156863, 'f1': 0.9078260869565218}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████| 115/115 [01:27<00:00,  1.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19: {'accuracy': 0.8774509803921569, 'f1': 0.9125874125874125}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model.to(device)\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch.to(device)\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch.to(device)\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        predictions = outputs.logits.argmax(dim=-1)\n",
    "        predictions, references = predictions, batch[\"labels\"]\n",
    "        metric.add_batch(\n",
    "            predictions=predictions,\n",
    "            references=references,\n",
    "        )\n",
    "\n",
    "    eval_metric = metric.compute()\n",
    "    print(f\"epoch {epoch}:\", eval_metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7734299c",
   "metadata": {},
   "source": [
    "## Share adapters on the 🤗 Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "afaf42dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/smangrul/roberta-large-peft-prefix-tuning/commit/a00e05a4c9a68e700221784f8e073c2e194637c3', commit_message='Upload model', commit_description='', oid='a00e05a4c9a68e700221784f8e073c2e194637c3', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.push_to_hub(\"smangrul/roberta-large-peft-prefix-tuning\", use_auth_token=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42b20e77",
   "metadata": {},
   "source": [
    "## Load adapters from the Hub\n",
    "\n",
    "You can also directly load adapters from the Hub using the commands below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "868e7580",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ce57b4de8ae4f868115733abc2fb883",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/373 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at roberta-large were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'roberta.pooler.dense.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.bias', 'classifier.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ace158c926a44b31a9b0ea80411bd7a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/8.14M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                                                                                   | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:06<00:00,  2.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accuracy': 0.8774509803921569, 'f1': 0.9125874125874125}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "peft_model_id = \"smangrul/roberta-large-peft-prefix-tuning\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
    "\n",
    "# Load the Lora model\n",
    "inference_model = PeftModel.from_pretrained(inference_model, peft_model_id)\n",
    "\n",
    "inference_model.to(device)\n",
    "inference_model.eval()\n",
    "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "    batch.to(device)\n",
    "    with torch.no_grad():\n",
    "        outputs = inference_model(**batch)\n",
    "    predictions = outputs.logits.argmax(dim=-1)\n",
    "    predictions, references = predictions, batch[\"labels\"]\n",
    "    metric.add_batch(\n",
    "        predictions=predictions,\n",
    "        references=references,\n",
    "    )\n",
    "\n",
    "eval_metric = metric.compute()\n",
    "print(eval_metric)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5 (v3.10.5:f377153967, Jun  6 2022, 12:36:10) [Clang 13.0.0 (clang-1300.0.29.30)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
