{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d36e1e93-ae93-4a4e-93c6-68fd868d2882",
   "metadata": {},
   "source": [
    "# Using VB-LoRA for sequence classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddfc0610-55f6-4343-a950-125ccf0f45ac",
   "metadata": {},
   "source": [
    "In this example, we fine-tune RoBERTa (`roberta-large`) on a sequence classification task (the MRPC task from GLUE) using VB-LoRA.\n",
    "\n",
    "This notebook is adapted from `examples/sequence_classification/VeRA.ipynb`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45addd81-d4f3-4dfd-960d-3920d347f0a6",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9935ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from peft import (\n",
    "    get_peft_model,\n",
    "    VBLoRAConfig,\n",
    "    PeftType,\n",
    ")\n",
    "\n",
    "import evaluate\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62c959bf-7cc2-49e0-b97e-4c10ec3b9bf3",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e3b13308",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f4fc7c3c750>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model and task selection.\n",
    "model_name_or_path = \"roberta-large\"\n",
    "task = \"mrpc\"\n",
    "peft_type = PeftType.VBLORA\n",
    "\n",
    "# Use the GPU when one is available instead of failing on CPU-only machines.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Training loop and tokenization settings.\n",
    "batch_size = 32\n",
    "num_epochs = 20\n",
    "max_length = 128  # passed to the tokenizer as the truncation length\n",
    "\n",
    "# VB-LoRA hyperparameters, forwarded to `VBLoRAConfig` below.\n",
    "rank = 4\n",
    "num_vectors = 90\n",
    "vector_length = 256\n",
    "\n",
    "# Seed PyTorch's global RNG for reproducibility.\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0526f571",
   "metadata": {},
   "outputs": [],
   "source": [
    "# VB-LoRA adapter configuration, built from the hyperparameters defined above.\n",
    "peft_config = VBLoRAConfig(\n",
    "    task_type=\"SEQ_CLS\",\n",
    "    target_modules=[\"key\", \"value\", \"query\", \"output.dense\", \"intermediate.dense\"],\n",
    "    r=rank,\n",
    "    topk=2,\n",
    "    num_vectors=num_vectors,\n",
    "    vector_length=vector_length,\n",
    "    vblora_dropout=0.0,\n",
    "    # Set to True to reduce storage space. Note that the saved parameters cannot be\n",
    "    # used to resume training from checkpoints.\n",
    "    save_only_topk_weights=True,\n",
    ")\n",
    "\n",
    "# Per-group learning rates used when the optimizer is built: one for vector-bank\n",
    "# parameters, one for logits parameters, and `head_lr` for every other parameter.\n",
    "vector_bank_lr = 1e-3\n",
    "logits_lr = 1e-2\n",
    "head_lr = 4e-3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c075c5d2-a457-4f37-a7f1-94fd0d277972",
   "metadata": {},
   "source": [
    "## Loading data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7bb52cb4-d1c3-4b04-8bf0-f39ca88af139",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints whose name contains \"gpt\", \"opt\" or \"bloom\" are padded on the left,\n",
    "# every other checkpoint on the right.\n",
    "padding_side = \"left\" if any(k in model_name_or_path for k in (\"gpt\", \"opt\", \"bloom\")) else \"right\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
    "\n",
    "# Reuse the EOS token for padding when the tokenizer defines no padding token.\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e69c5e1f-d27b-4264-a41e-fc9b99d025e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the selected GLUE task together with its matching evaluation metric.\n",
    "datasets = load_dataset(path=\"glue\", name=task)\n",
    "metric = evaluate.load(path=\"glue\", config_name=task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0209f778-c93b-40eb-a4e0-24c25db03980",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize a batch of sentence pairs, truncating each pair to at most `max_length` tokens.\"\"\"\n",
    "    return tokenizer(\n",
    "        examples[\"sentence1\"],\n",
    "        examples[\"sentence2\"],\n",
    "        truncation=True,\n",
    "        max_length=max_length,\n",
    "    )\n",
    "\n",
    "\n",
    "# Tokenize every split in batches and drop the raw columns that are no longer needed.\n",
    "tokenized_datasets = datasets.map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=[\"idx\", \"sentence1\", \"sentence2\"],\n",
    ")\n",
    "\n",
    "# We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the\n",
    "# transformers library\n",
    "tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7453954e-982c-46f0-b09c-589776e6d6cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(examples):\n",
    "    \"\"\"Pad a list of tokenized examples to the longest one and return PyTorch tensors.\"\"\"\n",
    "    return tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
    "\n",
    "\n",
    "# Both loaders share the batch size and collate function; only the training split is shuffled.\n",
    "loader_kwargs = dict(collate_fn=collate_fn, batch_size=batch_size)\n",
    "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, **loader_kwargs)\n",
    "eval_dataloader = DataLoader(tokenized_datasets[\"validation\"], shuffle=False, **loader_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3b9b2e8-f415-4d0f-9fb4-436f1a3585ea",
   "metadata": {},
   "source": [
    "## Preparing the VB-LoRA model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2ed5ac74",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 1,696,770 || all params: 357,058,564 || trainable%: 0.4752\n",
      "VB-LoRA params to-be-saved (float32-equivalent): 33,408 || total params to-be-saved: 1,085,058\n"
     ]
    }
   ],
   "source": [
    "# Load the base classifier and wrap it with the VB-LoRA adapter in one step.\n",
    "model = get_peft_model(\n",
    "    AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True, max_length=None),\n",
    "    peft_config,\n",
    ")\n",
    "\n",
    "# Report how many parameters are trained and how many will be saved.\n",
    "model.print_trainable_parameters()\n",
    "model.print_savable_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0d2d0381",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS\n",
    "from transformers.trainer_pt_utils import get_parameter_names\n",
    "\n",
    "# Parameter names are collected into sets so the membership tests below are O(1).\n",
    "# Weight decay applies to the names returned by `get_parameter_names`, minus any bias.\n",
    "decay_parameters = {\n",
    "    name for name in get_parameter_names(model, ALL_LAYERNORM_LAYERS) if \"bias\" not in name\n",
    "}\n",
    "named_parameters = list(model.named_parameters())\n",
    "vector_bank_parameters = {name for name, _ in named_parameters if \"vector_bank\" in name}\n",
    "logits_parameters = {name for name, _ in named_parameters if \"logits\" in name}\n",
    "\n",
    "# Everything that is neither a vector-bank nor a logits parameter is trained with `head_lr`.\n",
    "vblora_parameters = vector_bank_parameters | logits_parameters\n",
    "other_parameters = [(n, p) for n, p in named_parameters if n not in vblora_parameters]\n",
    "\n",
    "optimizer_grouped_parameters = [\n",
    "    {\n",
    "        \"params\": [p for n, p in other_parameters if n in decay_parameters],\n",
    "        \"weight_decay\": 0.1,\n",
    "        \"lr\": head_lr,\n",
    "    },\n",
    "    {\n",
    "        \"params\": [p for n, p in other_parameters if n not in decay_parameters],\n",
    "        \"weight_decay\": 0.0,\n",
    "        \"lr\": head_lr,\n",
    "    },\n",
    "    {\n",
    "        \"params\": [p for n, p in named_parameters if n in vector_bank_parameters],\n",
    "        \"lr\": vector_bank_lr,\n",
    "        \"weight_decay\": 0.0,\n",
    "    },\n",
    "    {\n",
    "        \"params\": [p for n, p in named_parameters if n in logits_parameters],\n",
    "        \"lr\": logits_lr,\n",
    "        \"weight_decay\": 0.0,\n",
    "    },\n",
    "]\n",
    "\n",
    "optimizer = AdamW(optimizer_grouped_parameters)\n",
    "\n",
    "# Linear warmup over the first 6% of all optimizer steps, then linear decay.\n",
    "# The scheduler documents integer step counts, so the warmup length is rounded.\n",
    "num_training_steps = len(train_dataloader) * num_epochs\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=round(0.06 * num_training_steps),\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0dd5aa8-977b-4ac0-8b96-884b17bcdd00",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fa0e73be",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/115 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "100%|██████████| 115/115 [00:34<00:00,  3.33it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 0: {'accuracy': 0.6691176470588235, 'f1': 0.786053882725832}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.37it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.83it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1: {'accuracy': 0.5833333333333334, 'f1': 0.6136363636363636}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.34it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2: {'accuracy': 0.7107843137254902, 'f1': 0.8238805970149253}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.34it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3: {'accuracy': 0.8284313725490197, 'f1': 0.8833333333333333}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.34it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4: {'accuracy': 0.8480392156862745, 'f1': 0.8847583643122676}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.30it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5: {'accuracy': 0.8676470588235294, 'f1': 0.898876404494382}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.31it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6: {'accuracy': 0.8602941176470589, 'f1': 0.9035532994923858}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.32it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7: {'accuracy': 0.8774509803921569, 'f1': 0.911660777385159}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.33it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8: {'accuracy': 0.8872549019607843, 'f1': 0.9172661870503597}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.32it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9: {'accuracy': 0.875, 'f1': 0.9113043478260869}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.32it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10: {'accuracy': 0.8823529411764706, 'f1': 0.9166666666666666}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.33it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11: {'accuracy': 0.8970588235294118, 'f1': 0.9252669039145908}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.32it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12: {'accuracy': 0.8946078431372549, 'f1': 0.9246935201401051}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.33it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13: {'accuracy': 0.9068627450980392, 'f1': 0.9316546762589928}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.33it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14: {'accuracy': 0.8946078431372549, 'f1': 0.9225225225225225}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.33it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15: {'accuracy': 0.8995098039215687, 'f1': 0.926391382405745}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.30it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16: {'accuracy': 0.9068627450980392, 'f1': 0.9316546762589928}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.31it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17: {'accuracy': 0.8921568627450981, 'f1': 0.9217081850533808}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.33it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18: {'accuracy': 0.8995098039215687, 'f1': 0.9266547406082289}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 115/115 [00:34<00:00,  3.33it/s]\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19: {'accuracy': 0.9044117647058824, 'f1': 0.9297297297297298}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    # Optimization pass over the training split.\n",
    "    model.train()\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        batch = batch.to(device)\n",
    "        loss = model(**batch).loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    # Evaluation pass: collect predictions on the validation split, then score them.\n",
    "    model.eval()\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        batch = batch.to(device)\n",
    "        with torch.no_grad():\n",
    "            logits = model(**batch).logits\n",
    "        metric.add_batch(predictions=logits.argmax(dim=-1), references=batch[\"labels\"])\n",
    "\n",
    "    eval_metric = metric.compute()\n",
    "    print(f\"epoch {epoch}:\", eval_metric)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2b2caca",
   "metadata": {},
   "source": [
    "## Share adapters on the 🤗 Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7b23af6f-cf6e-486f-9d10-0eada95b631f",
   "metadata": {},
   "outputs": [],
   "source": [
    "account_id = ...  # your Hugging Face Hub account ID"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "990b3c93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail early with a clear message if the `...` placeholder was not replaced;\n",
    "# otherwise the repo id below would be formatted as \"Ellipsis/...\".\n",
    "if account_id is ...:\n",
    "    raise ValueError(\"Set `account_id` to your Hugging Face Hub account ID before pushing.\")\n",
    "\n",
    "model.push_to_hub(f\"{account_id}/roberta-large-peft-vblora\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d140b26",
   "metadata": {},
   "source": [
    "## Load adapters from the Hub\n",
    "\n",
    "You can also directly load adapters from the Hub using the commands below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c283e028-b349-46b0-a20e-cde0ee5fbd7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "320b10a0-4ea8-4786-9f3c-4670019c6b18",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# Read the adapter config from the Hub to find out which base checkpoint it was trained on,\n",
    "# then load the matching tokenizer and base model.\n",
    "peft_model_id = f\"{account_id}/roberta-large-peft-vblora\"\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
    "base_inference_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b3a94049-bc01-4f2e-8cf9-66daf24a4402",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model: attach the adapter to the base model. The base model keeps its own name so\n",
    "# that re-running this cell does not wrap an already wrapped model.\n",
    "inference_model = PeftModel.from_pretrained(base_inference_model, peft_model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bd919fef-4e9a-4dc5-a957-7b879cfc5d38",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/13 [00:00<?, ?it/s]You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n",
      "100%|██████████| 13/13 [00:01<00:00,  7.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accuracy': 0.9044117647058824, 'f1': 0.9297297297297298}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "inference_model.to(device)\n",
    "inference_model.eval()\n",
    "\n",
    "# Score the reloaded adapter on the validation split without tracking gradients.\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        batch = batch.to(device)\n",
    "        logits = inference_model(**batch).logits\n",
    "        metric.add_batch(predictions=logits.argmax(dim=-1), references=batch[\"labels\"])\n",
    "\n",
    "eval_metric = metric.compute()\n",
    "print(eval_metric)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
