{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90b47c6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "# Both files go through the generic JSON loader. `data_files` is a bare string,\n",
    "# so each result is a DatasetDict whose single split is named 'train' -- this\n",
    "# holds for the test file too, hence the later val_datasets['train'] lookups.\n",
    "TRAIN_URL = 'https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/hellaswag/train.json'\n",
    "TEST_URL = 'https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/hellaswag/test.json'\n",
    "\n",
    "raw_datasets = load_dataset('json', data_files=TRAIN_URL)\n",
    "val_datasets = load_dataset('json', data_files=TEST_URL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7756a810",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['instruction', 'input', 'output', 'answer'],\n",
       "        num_rows: 10042\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The held-out file was loaded without a split name, so its rows are listed under 'train'.\n",
    "val_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "769a952a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "\n",
    "\n",
    "def encode_answers(split, encoder):\n",
    "    \"\"\"Attach an integer 'labels' column derived from the split's 'answer' column.\n",
    "\n",
    "    An existing 'labels' column is dropped first, so running this cell again\n",
    "    does not fail on a duplicated column name.\n",
    "\n",
    "    Returns:\n",
    "        (split_with_labels, encoded_labels) where encoded_labels is the array\n",
    "        produced by `encoder.transform`.\n",
    "    \"\"\"\n",
    "    encoded = encoder.transform(split['answer'])\n",
    "    if 'labels' in split.column_names:\n",
    "        split = split.remove_columns('labels')\n",
    "    return split.add_column('labels', encoded), encoded\n",
    "\n",
    "\n",
    "# Fit on the training answers only; the validation answers are mapped with the\n",
    "# same classes.\n",
    "label_encoder = LabelEncoder()\n",
    "label_encoder.fit(raw_datasets['train']['answer'])\n",
    "\n",
    "raw_datasets['train'], train_labels = encode_answers(raw_datasets['train'], label_encoder)\n",
    "val_datasets['train'], val_labels = encode_answers(val_datasets['train'], label_encoder)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9badafa3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.int64(3)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Largest encoded class id present in the training labels.\n",
    "train_labels.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de228bb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py:559: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "\n",
    "model_name = \"microsoft/deberta-v3-base\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# Pad on the right: the sequence-classification head pools the hidden state at\n",
    "# position 0, which has to be [CLS]. With left padding that position holds\n",
    "# [PAD] for every example shorter than the longest one in its batch.\n",
    "tokenizer.padding_side = 'right'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ed721fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 39905/39905 [00:02<00:00, 15816.55 examples/s]\n",
      "Map: 100%|██████████| 10042/10042 [00:00<00:00, 16635.95 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['instruction', 'input', 'output', 'answer', 'labels'],\n",
      "    num_rows: 39905\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['instruction', 'input', 'output', 'answer', 'labels'],\n",
      "    num_rows: 10042\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "# NOTE(review): mask_token is not read by any other cell visible in this notebook.\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"Build the model prompt from one example's 'instruction' field.\n",
    "\n",
    "    Only the text before the first 'format:' is kept (the whole instruction if\n",
    "    that marker is absent); it is prefixed with '# input: ' and followed by ':'.\n",
    "    \"\"\"\n",
    "    question, _, _ = data_point[\"instruction\"].partition('format:')\n",
    "    return f\"# input: {question}:\"\n",
    "               \n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "    \"\"\"Replace an example's 'input' field with the prompt from `generate_prompt`.\n",
    "\n",
    "    Only 'input' is modified; 'labels' and every other field pass through\n",
    "    unchanged. Despite the name, no column is added here.\n",
    "\n",
    "    Returns:\n",
    "        The same example mapping, updated in place.\n",
    "    \"\"\"\n",
    "    example['input'] = generate_prompt(example)\n",
    "    return example\n",
    "\n",
    "# Rewrite the 'input' field of both splits. The test file's rows also live\n",
    "# under its 'train' key.\n",
    "train_data, val_data = (\n",
    "    source['train'].map(add_label_column)\n",
    "    for source in (raw_datasets, val_datasets)\n",
    ")\n",
    "\n",
    "# Show the resulting schema and row counts.\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9e33204c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# input: Please choose the correct ending to complete the given sentence: Removing ice from car: Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then\\n\\nEnding1: , the man adds wax to the windshield and cuts it. Ending2: , a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled. Ending3: , the man puts on a christmas coat, knitted with netting. Ending4: , the man continues removing the snow on his car.\\n\\nAnswer :'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# First training prompt after the rewrite above.\n",
    "train_data[0]['input']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a9fde6d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 39905/39905 [00:08<00:00, 4985.37 examples/s]\n",
      "Map: 100%|██████████| 10042/10042 [00:02<00:00, 4756.55 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "# Pad on the right so that position 0 of every row is [CLS]. The classification\n",
    "# head pools that position, and left padding would hand it [PAD] for all but the\n",
    "# longest example in a batch.\n",
    "tokenizer.padding_side = 'right'\n",
    "\n",
    "# Raw text columns dropped once tokenized; 'labels' is kept.\n",
    "col_to_delete = ['instruction', 'input', 'output', 'answer']\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Tokenize a batch of prompts, truncating each one to at most 512 tokens.\"\"\"\n",
    "    prompts = examples['input']\n",
    "    return tokenizer(prompts, max_length=512, truncation=True)\n",
    "\n",
    "# Tokenize both splits in batches and drop the raw text columns.\n",
    "tokenized_train_data, tokenized_val_data = (\n",
    "    split.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "    for split in (train_data, val_data)\n",
    ")\n",
    "\n",
    "# Have both datasets hand back torch tensors.\n",
    "tokenized_train_data.set_format(type=\"torch\")\n",
    "tokenized_val_data.set_format(type=\"torch\")\n",
    "\n",
    "# Pads every batch to the length of its longest example.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1931ed6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] # input: Please choose the correct ending to complete the given sentence: Getting a haircut: He also trims the back and sides of his head with the clippers. He uses scissors to trim the hair and give it a finished look. the model Ending1: poses with how to complete the look using the partial and then suggests a computer graphic tool. Ending2: poses and begins to talk while holding up a piece of metal. Ending3: uses some hair gel to style his hair after the haircut. Ending4: uses rollers to take some of the hair off the brush and enjoy the ending. Answer :[SEP]'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip one tokenized training example back to text as a sanity check.\n",
    "tokenizer.decode(tokenized_train_data[10]['input_ids'])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd6b985",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "25900f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['instruction', 'input', 'output', 'answer', 'labels'],\n",
       "    num_rows: 10042\n",
       "})"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# val_data still carries the raw text columns; they were removed only from the tokenized copy.\n",
    "val_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1fdaa612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "355"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token count of every tokenized training example, and the longest of them.\n",
    "all_lengths = list(map(len, tokenized_train_data['input_ids']))\n",
    "mx = max(all_lengths)\n",
    "mx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d6618d0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "# Number of training examples longer than 512 tokens. Tokenization already\n",
    "# truncates to 512, so this can only confirm that the limit was applied.\n",
    "count = sum(1 for ids in tokenized_train_data['input_ids'] if len(ids) > 512)\n",
    "print(count)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f1005af8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total labels (classes): 4\n"
     ]
    }
   ],
   "source": [
    "# One output class per distinct answer string the label encoder was fitted on.\n",
    "num_labels = label_encoder.classes_.shape[0]\n",
    "print(f\"Total labels (classes): {num_labels}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7a46cd19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "# from modeling import MLMSequenceClassification\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "# Start from the checkpoint's own configuration, then size the classification\n",
    "# head for this task before the model is instantiated.\n",
    "config = AutoConfig.from_pretrained(model_name)\n",
    "config.mask_token_id = tokenizer.mask_token_id\n",
    "config.num_labels = num_labels\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_name,\n",
    "    config=config,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "864ccb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RoCoFT\n",
    "\n",
    "# NOTE(review): RoCoFT is not defined in this notebook. Presumably this call\n",
    "# restricts which weights are trainable -- confirm that the newly initialized\n",
    "# classifier/pooler weights still receive updates, since the logged training\n",
    "# loss stays near 1.387 for the whole run.\n",
    "RoCoFT.PEFT(model, method='column', rank=3) \n",
    "#targets=['key', 'value', 'dense', 'query'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "bef34afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import evaluate\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute macro precision/recall/F1 and accuracy for one evaluation pass.\n",
    "\n",
    "    Args:\n",
    "        eval_pred: a (logits, labels) pair; the predicted class is the argmax\n",
    "            over the last axis of the logits.\n",
    "\n",
    "    Returns:\n",
    "        dict with the keys 'precision', 'recall', 'f1-score' and 'accuracy'.\n",
    "    \"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "\n",
    "    # zero_division=0 gives a class with no predicted samples a score of 0, the\n",
    "    # same value as the default, but without an UndefinedMetricWarning on every\n",
    "    # evaluation.\n",
    "    macro = {\"average\": \"macro\", \"zero_division\": 0}\n",
    "    return {\n",
    "        \"precision\": metrics.precision_score(labels, predictions, **macro),\n",
    "        \"recall\": metrics.recall_score(labels, predictions, **macro),\n",
    "        \"f1-score\": metrics.f1_score(labels, predictions, **macro),\n",
    "        'accuracy': metrics.accuracy_score(labels, predictions),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7dbcf96a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=1e-4,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=12,\n",
    "    weight_decay=0.0,\n",
    "    lr_scheduler_type=\"cosine\",\n",
    "    warmup_steps=100,\n",
    "    logging_steps=1000,\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=1000,\n",
    "    save_strategy=\"steps\",\n",
    "    # Checkpoint at every evaluation. load_best_model_at_end can only restore a\n",
    "    # checkpoint that was actually written, and an interval larger than the\n",
    "    # total number of training steps never writes one.\n",
    "    save_steps=1000,\n",
    "    save_total_limit=2,\n",
    "    load_best_model_at_end=True,\n",
    "    # Select the best checkpoint by the 'accuracy' value from compute_metrics.\n",
    "    metric_for_best_model=\"accuracy\",\n",
    "    greater_is_better=True,\n",
    ")\n",
    "\n",
    "# Wire the model, both tokenized splits, the padding collator and the metric\n",
    "# function into a single Trainer.\n",
    "trainer = Trainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "557cdbf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='29940' max='29940' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [29940/29940 3:07:44, Epoch 12/12]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>1.393100</td>\n",
       "      <td>1.387998</td>\n",
       "      <td>0.246779</td>\n",
       "      <td>0.250973</td>\n",
       "      <td>0.222988</td>\n",
       "      <td>0.252539</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>1.389500</td>\n",
       "      <td>1.385791</td>\n",
       "      <td>0.265868</td>\n",
       "      <td>0.255138</td>\n",
       "      <td>0.182253</td>\n",
       "      <td>0.260008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>1.388300</td>\n",
       "      <td>1.386143</td>\n",
       "      <td>0.254766</td>\n",
       "      <td>0.251217</td>\n",
       "      <td>0.209040</td>\n",
       "      <td>0.251544</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>1.388300</td>\n",
       "      <td>1.386205</td>\n",
       "      <td>0.255199</td>\n",
       "      <td>0.254933</td>\n",
       "      <td>0.201133</td>\n",
       "      <td>0.251942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>1.387800</td>\n",
       "      <td>1.386512</td>\n",
       "      <td>0.190450</td>\n",
       "      <td>0.254821</td>\n",
       "      <td>0.202187</td>\n",
       "      <td>0.253436</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>1.387800</td>\n",
       "      <td>1.386499</td>\n",
       "      <td>0.200626</td>\n",
       "      <td>0.254076</td>\n",
       "      <td>0.142255</td>\n",
       "      <td>0.251842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7000</td>\n",
       "      <td>1.388100</td>\n",
       "      <td>1.386565</td>\n",
       "      <td>0.230340</td>\n",
       "      <td>0.255056</td>\n",
       "      <td>0.190018</td>\n",
       "      <td>0.252141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8000</td>\n",
       "      <td>1.387700</td>\n",
       "      <td>1.386425</td>\n",
       "      <td>0.240418</td>\n",
       "      <td>0.253798</td>\n",
       "      <td>0.208060</td>\n",
       "      <td>0.256224</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9000</td>\n",
       "      <td>1.387100</td>\n",
       "      <td>1.386585</td>\n",
       "      <td>0.257397</td>\n",
       "      <td>0.254581</td>\n",
       "      <td>0.216202</td>\n",
       "      <td>0.252639</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10000</td>\n",
       "      <td>1.388500</td>\n",
       "      <td>1.386840</td>\n",
       "      <td>0.248019</td>\n",
       "      <td>0.258219</td>\n",
       "      <td>0.206599</td>\n",
       "      <td>0.255826</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11000</td>\n",
       "      <td>1.387800</td>\n",
       "      <td>1.386571</td>\n",
       "      <td>0.188335</td>\n",
       "      <td>0.251688</td>\n",
       "      <td>0.182927</td>\n",
       "      <td>0.250647</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12000</td>\n",
       "      <td>1.387500</td>\n",
       "      <td>1.386181</td>\n",
       "      <td>0.242812</td>\n",
       "      <td>0.255216</td>\n",
       "      <td>0.188257</td>\n",
       "      <td>0.252041</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13000</td>\n",
       "      <td>1.387800</td>\n",
       "      <td>1.386371</td>\n",
       "      <td>0.180084</td>\n",
       "      <td>0.252905</td>\n",
       "      <td>0.155485</td>\n",
       "      <td>0.250050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14000</td>\n",
       "      <td>1.387200</td>\n",
       "      <td>1.386254</td>\n",
       "      <td>0.269908</td>\n",
       "      <td>0.260221</td>\n",
       "      <td>0.184381</td>\n",
       "      <td>0.255925</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15000</td>\n",
       "      <td>1.388200</td>\n",
       "      <td>1.386304</td>\n",
       "      <td>0.189629</td>\n",
       "      <td>0.255605</td>\n",
       "      <td>0.213845</td>\n",
       "      <td>0.252539</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16000</td>\n",
       "      <td>1.387700</td>\n",
       "      <td>1.386276</td>\n",
       "      <td>0.314453</td>\n",
       "      <td>0.257390</td>\n",
       "      <td>0.178801</td>\n",
       "      <td>0.253236</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17000</td>\n",
       "      <td>1.387700</td>\n",
       "      <td>1.386240</td>\n",
       "      <td>0.439420</td>\n",
       "      <td>0.254023</td>\n",
       "      <td>0.209438</td>\n",
       "      <td>0.252440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18000</td>\n",
       "      <td>1.387300</td>\n",
       "      <td>1.386255</td>\n",
       "      <td>0.286892</td>\n",
       "      <td>0.256392</td>\n",
       "      <td>0.170510</td>\n",
       "      <td>0.252539</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19000</td>\n",
       "      <td>1.387300</td>\n",
       "      <td>1.386159</td>\n",
       "      <td>0.251170</td>\n",
       "      <td>0.259617</td>\n",
       "      <td>0.216636</td>\n",
       "      <td>0.257518</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20000</td>\n",
       "      <td>1.387400</td>\n",
       "      <td>1.386262</td>\n",
       "      <td>0.187751</td>\n",
       "      <td>0.252473</td>\n",
       "      <td>0.206750</td>\n",
       "      <td>0.250846</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21000</td>\n",
       "      <td>1.387200</td>\n",
       "      <td>1.386181</td>\n",
       "      <td>0.256159</td>\n",
       "      <td>0.261309</td>\n",
       "      <td>0.206088</td>\n",
       "      <td>0.258415</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22000</td>\n",
       "      <td>1.387800</td>\n",
       "      <td>1.386206</td>\n",
       "      <td>0.275465</td>\n",
       "      <td>0.255831</td>\n",
       "      <td>0.207302</td>\n",
       "      <td>0.254431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23000</td>\n",
       "      <td>1.387200</td>\n",
       "      <td>1.386279</td>\n",
       "      <td>0.191602</td>\n",
       "      <td>0.252976</td>\n",
       "      <td>0.185284</td>\n",
       "      <td>0.252241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24000</td>\n",
       "      <td>1.387000</td>\n",
       "      <td>1.386229</td>\n",
       "      <td>0.192912</td>\n",
       "      <td>0.254958</td>\n",
       "      <td>0.190123</td>\n",
       "      <td>0.254133</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25000</td>\n",
       "      <td>1.386800</td>\n",
       "      <td>1.386274</td>\n",
       "      <td>0.191432</td>\n",
       "      <td>0.257463</td>\n",
       "      <td>0.218193</td>\n",
       "      <td>0.255128</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26000</td>\n",
       "      <td>1.387600</td>\n",
       "      <td>1.386240</td>\n",
       "      <td>0.192702</td>\n",
       "      <td>0.258595</td>\n",
       "      <td>0.207386</td>\n",
       "      <td>0.255726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27000</td>\n",
       "      <td>1.387000</td>\n",
       "      <td>1.386250</td>\n",
       "      <td>0.192040</td>\n",
       "      <td>0.258501</td>\n",
       "      <td>0.212576</td>\n",
       "      <td>0.255826</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28000</td>\n",
       "      <td>1.387400</td>\n",
       "      <td>1.386253</td>\n",
       "      <td>0.192349</td>\n",
       "      <td>0.258254</td>\n",
       "      <td>0.209861</td>\n",
       "      <td>0.255427</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29000</td>\n",
       "      <td>1.387000</td>\n",
       "      <td>1.386248</td>\n",
       "      <td>0.191984</td>\n",
       "      <td>0.257913</td>\n",
       "      <td>0.210657</td>\n",
       "      <td>0.255128</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n",
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=29940, training_loss=1.3878104883269142, metrics={'train_runtime': 11265.6436, 'train_samples_per_second': 42.506, 'train_steps_per_second': 2.658, 'total_flos': 314436494548800.0, 'train_loss': 1.3878104883269142, 'epoch': 12.0})"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run fine-tuning with the arguments configured above; the displayed value is the call's return value.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82b8833b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
