{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fc4b1225-0320-4e36-ace5-39054fb6d42d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "from peft import LoraConfig, get_peft_model, PeftConfig, PeftModel, prepare_model_for_kbit_training, AdaLoraConfig\n",
    "\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load the RTE task of the GLUE benchmark.\n",
    "raw_datasets = load_dataset(path=\"glue\", name=\"rte\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2df9a044-29af-4e31-8094-f80424ca1e18",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py:559: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "\n",
    "# Hub checkpoint name; the tokenizer here and the model further down both load from it.\n",
    "model_name = \"microsoft/deberta-v3-base\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "96945cf0-5205-4c66-a66e-30dc4ecfb95b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 2490/2490 [00:00<00:00, 8811.77 examples/s]\n",
      "Map: 100%|██████████| 277/277 [00:00<00:00, 9311.19 examples/s]\n",
      "Map: 100%|██████████| 3000/3000 [00:00<00:00, 10156.31 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "# Pad on the right so that [CLS] stays at position 0 of every row: the DeBERTa\n",
    "# sequence-classification head pools the hidden state at index 0, and left\n",
    "# padding would place a [PAD] token there for every sequence shorter than the\n",
    "# longest one in its batch.\n",
    "tokenizer.padding_side = 'right'\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "# Raw text columns to drop once they have been tokenized.\n",
    "col_to_delete = ['sentence1','sentence2']\n",
    "\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Turn a batch of sentence pairs into tokenized entailment prompts.\n",
    "\n",
    "    Reads the ``sentence1`` (premise) and ``sentence2`` (hypothesis) columns of\n",
    "    ``examples``, formats one prompt per pair that ends with the tokenizer's\n",
    "    mask token, and returns the tokenizer output. No padding is applied here;\n",
    "    truncation is enabled with a maximum length of 512.\n",
    "    \"\"\"\n",
    "    sentence_pairs = zip(examples[\"sentence1\"], examples[\"sentence2\"])\n",
    "    prompts = []\n",
    "    for premise, hypothesis in sentence_pairs:\n",
    "        prompts.append(\n",
    "            f\"Premise: {premise} Hypothesis: {hypothesis} \"\n",
    "            f\"Does the premise imply the hypothesis? Answer:{mask_token}\"\n",
    "        )\n",
    "    return tokenizer(prompts, max_length=512, truncation=True, padding=False)\n",
    "\n",
    "# Tokenize every split in batches and drop the raw text columns.\n",
    "tokenized_datasets = raw_datasets.map(\n",
    "    preprocessing_function,\n",
    "    batched=True,\n",
    "    remove_columns=col_to_delete,\n",
    ")\n",
    "tokenized_datasets.set_format(\"torch\")\n",
    "\n",
    "# Collator that pads each batch to the longest sequence it contains.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3ffcba0a-0c79-444c-b456-9553c0081c76",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] Premise: The international humanitarian aid organization, Doctors Without Borders/Medecins Sans Frontieres (MSF), continues to treat victims of violence in all locations where it is present in Darfur. Hypothesis: Doctors Without Borders is an international aid organization. Does the premise imply the hypothesis? Answer:[MASK][SEP]'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: decode one tokenized validation example (row first, then column).\n",
    "tokenizer.decode(tokenized_datasets['validation'][10]['input_ids'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "59055d75-85e1-46b6-9363-5c54536cbc13",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "128000"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.mask_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "15ce4d0c-d9b0-4c14-8afc-c2b0aa065517",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "# from modeling import MLMSequenceClassification\n",
    "\n",
    "# Start from the checkpoint's config and record the tokenizer's mask token id on it.\n",
    "config = AutoConfig.from_pretrained(model_name)\n",
    "config.mask_token_id = tokenizer.mask_token_id\n",
    "\n",
    "# Sequence-classification model built from the same checkpoint and the config above.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_name,\n",
    "    config=config,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "afe53da5-1b8a-4333-bb48-4e449c03f386",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RoCoFT\n",
    "\n",
    "# Apply RoCoFT to `model` using the 'column' method with rank 3.\n",
    "# Optional restriction that was tried before: targets=['key', 'value', 'dense', 'query']\n",
    "RoCoFT.PEFT(model, rank=3, method='column')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "91862ad9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DebertaV2ForSequenceClassification(\n",
       "  (deberta): DebertaV2Model(\n",
       "    (embeddings): DebertaV2Embeddings(\n",
       "      (word_embeddings): Embedding(128100, 768, padding_idx=0)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "      (dropout): Dropout(p=0.1, inplace=False)\n",
       "    )\n",
       "    (encoder): DebertaV2Encoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x DebertaV2Layer(\n",
       "          (attention): DebertaV2Attention(\n",
       "            (self): DisentangledSelfAttention(\n",
       "              (query_proj): column()\n",
       "              (key_proj): column()\n",
       "              (value_proj): column()\n",
       "              (pos_dropout): Dropout(p=0.1, inplace=False)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "            (output): DebertaV2SelfOutput(\n",
       "              (dense): column()\n",
       "              (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "              (dropout): Dropout(p=0.1, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): DebertaV2Intermediate(\n",
       "            (dense): column()\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): DebertaV2Output(\n",
       "            (dense): column()\n",
       "            (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (rel_embeddings): Embedding(512, 768)\n",
       "      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)\n",
       "    )\n",
       "  )\n",
       "  (pooler): ContextPooler(\n",
       "    (dense): column()\n",
       "    (dropout): Dropout(p=0, inplace=False)\n",
       "  )\n",
       "  (classifier): column()\n",
       "  (dropout): Dropout(p=0.1, inplace=False)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "511c59e6-4783-4e60-92d1-4eff574a60b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute macro precision/recall/F1 and accuracy for one evaluation pass.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: Pair ``(logits, labels)``; the predicted class is the argmax\n",
    "            over the last axis of ``logits``.\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys ``precision``, ``recall``, ``f1-score`` and ``accuracy``.\n",
    "    \"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "\n",
    "    # zero_division=0 keeps the value sklearn already falls back to (0.0) when a\n",
    "    # class is never predicted, but without emitting an UndefinedMetricWarning on\n",
    "    # every such evaluation step.\n",
    "    macro = {\"average\": \"macro\", \"zero_division\": 0}\n",
    "    return {\n",
    "        \"precision\": metrics.precision_score(labels, predictions, **macro),\n",
    "        \"recall\": metrics.recall_score(labels, predictions, **macro),\n",
    "        \"f1-score\": metrics.f1_score(labels, predictions, **macro),\n",
    "        \"accuracy\": metrics.accuracy_score(labels, predictions),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5e6599de-44c7-4bc0-8b73-5bca6d73913e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "# Checkpoints are written on the same cadence as evaluation. save_steps has to be\n",
    "# reachable within the run and line up with eval_steps; otherwise no checkpoint\n",
    "# exists for any evaluated step and load_best_model_at_end has nothing to restore.\n",
    "# With no metric_for_best_model set, the best checkpoint is chosen by eval loss.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=2e-3,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=20,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=100,\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=100,\n",
    "    save_total_limit=2,  # the best checkpoint is always retained on top of the most recent one\n",
    "    logging_steps=100,\n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "# Bundle the model, both data splits, the padding collator and the metric callback.\n",
    "trainer = Trainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"validation\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3d9b3ed4-3cf2-4540-89dc-cde03dcd1dab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='3120' max='3120' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [3120/3120 09:53, Epoch 20/20]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.700500</td>\n",
       "      <td>0.682763</td>\n",
       "      <td>0.613458</td>\n",
       "      <td>0.533279</td>\n",
       "      <td>0.444673</td>\n",
       "      <td>0.555957</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.691900</td>\n",
       "      <td>0.687468</td>\n",
       "      <td>0.613423</td>\n",
       "      <td>0.529070</td>\n",
       "      <td>0.433201</td>\n",
       "      <td>0.552347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.696500</td>\n",
       "      <td>0.691412</td>\n",
       "      <td>0.263538</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.345154</td>\n",
       "      <td>0.527076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.696700</td>\n",
       "      <td>0.693924</td>\n",
       "      <td>0.236462</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.321078</td>\n",
       "      <td>0.472924</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.696000</td>\n",
       "      <td>0.693107</td>\n",
       "      <td>0.764493</td>\n",
       "      <td>0.503817</td>\n",
       "      <td>0.353547</td>\n",
       "      <td>0.530686</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.694500</td>\n",
       "      <td>0.691806</td>\n",
       "      <td>0.263538</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.345154</td>\n",
       "      <td>0.527076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.694100</td>\n",
       "      <td>0.693257</td>\n",
       "      <td>0.236462</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.321078</td>\n",
       "      <td>0.472924</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.696000</td>\n",
       "      <td>0.691512</td>\n",
       "      <td>0.263538</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.345154</td>\n",
       "      <td>0.527076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.694600</td>\n",
       "      <td>0.697006</td>\n",
       "      <td>0.236462</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.321078</td>\n",
       "      <td>0.472924</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.695200</td>\n",
       "      <td>0.698315</td>\n",
       "      <td>0.386577</td>\n",
       "      <td>0.470930</td>\n",
       "      <td>0.339706</td>\n",
       "      <td>0.447653</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.693500</td>\n",
       "      <td>0.697395</td>\n",
       "      <td>0.488477</td>\n",
       "      <td>0.492445</td>\n",
       "      <td>0.436304</td>\n",
       "      <td>0.476534</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.697400</td>\n",
       "      <td>0.697299</td>\n",
       "      <td>0.467028</td>\n",
       "      <td>0.486484</td>\n",
       "      <td>0.387461</td>\n",
       "      <td>0.465704</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.691400</td>\n",
       "      <td>0.689042</td>\n",
       "      <td>0.263538</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.345154</td>\n",
       "      <td>0.527076</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.677700</td>\n",
       "      <td>0.719051</td>\n",
       "      <td>0.628515</td>\n",
       "      <td>0.583159</td>\n",
       "      <td>0.532516</td>\n",
       "      <td>0.566787</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.614000</td>\n",
       "      <td>0.609063</td>\n",
       "      <td>0.697972</td>\n",
       "      <td>0.664331</td>\n",
       "      <td>0.656157</td>\n",
       "      <td>0.675090</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.573100</td>\n",
       "      <td>0.620925</td>\n",
       "      <td>0.759485</td>\n",
       "      <td>0.674501</td>\n",
       "      <td>0.655552</td>\n",
       "      <td>0.689531</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.536800</td>\n",
       "      <td>0.590465</td>\n",
       "      <td>0.704072</td>\n",
       "      <td>0.704669</td>\n",
       "      <td>0.703782</td>\n",
       "      <td>0.703971</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.476100</td>\n",
       "      <td>0.547032</td>\n",
       "      <td>0.749526</td>\n",
       "      <td>0.733818</td>\n",
       "      <td>0.733654</td>\n",
       "      <td>0.740072</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.455900</td>\n",
       "      <td>0.677508</td>\n",
       "      <td>0.750838</td>\n",
       "      <td>0.703571</td>\n",
       "      <td>0.696512</td>\n",
       "      <td>0.714801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.423800</td>\n",
       "      <td>0.639312</td>\n",
       "      <td>0.739748</td>\n",
       "      <td>0.688304</td>\n",
       "      <td>0.678650</td>\n",
       "      <td>0.700361</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.349100</td>\n",
       "      <td>0.760036</td>\n",
       "      <td>0.759146</td>\n",
       "      <td>0.731857</td>\n",
       "      <td>0.730195</td>\n",
       "      <td>0.740072</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.338700</td>\n",
       "      <td>0.707067</td>\n",
       "      <td>0.733394</td>\n",
       "      <td>0.701715</td>\n",
       "      <td>0.697466</td>\n",
       "      <td>0.711191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.277900</td>\n",
       "      <td>0.663429</td>\n",
       "      <td>0.756179</td>\n",
       "      <td>0.746052</td>\n",
       "      <td>0.746623</td>\n",
       "      <td>0.750903</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.273100</td>\n",
       "      <td>0.699806</td>\n",
       "      <td>0.770049</td>\n",
       "      <td>0.760928</td>\n",
       "      <td>0.761754</td>\n",
       "      <td>0.765343</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.262200</td>\n",
       "      <td>0.736953</td>\n",
       "      <td>0.761550</td>\n",
       "      <td>0.731465</td>\n",
       "      <td>0.729404</td>\n",
       "      <td>0.740072</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.208800</td>\n",
       "      <td>0.825694</td>\n",
       "      <td>0.762033</td>\n",
       "      <td>0.735674</td>\n",
       "      <td>0.734319</td>\n",
       "      <td>0.743682</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.223200</td>\n",
       "      <td>0.777725</td>\n",
       "      <td>0.768008</td>\n",
       "      <td>0.752118</td>\n",
       "      <td>0.752431</td>\n",
       "      <td>0.758123</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.206700</td>\n",
       "      <td>0.797158</td>\n",
       "      <td>0.755932</td>\n",
       "      <td>0.736850</td>\n",
       "      <td>0.736413</td>\n",
       "      <td>0.743682</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.211000</td>\n",
       "      <td>0.829470</td>\n",
       "      <td>0.759841</td>\n",
       "      <td>0.736066</td>\n",
       "      <td>0.735050</td>\n",
       "      <td>0.743682</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.198400</td>\n",
       "      <td>0.838041</td>\n",
       "      <td>0.764922</td>\n",
       "      <td>0.739491</td>\n",
       "      <td>0.738425</td>\n",
       "      <td>0.747292</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.190200</td>\n",
       "      <td>0.825895</td>\n",
       "      <td>0.757810</td>\n",
       "      <td>0.736458</td>\n",
       "      <td>0.735747</td>\n",
       "      <td>0.743682</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=3120, training_loss=0.4992836372974591, metrics={'train_runtime': 593.9058, 'train_samples_per_second': 83.852, 'train_steps_per_second': 5.253, 'total_flos': 19094072276928.0, 'train_loss': 0.4992836372974591, 'epoch': 20.0})"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "878c2bf0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e82fc391",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b498c89",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
