{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e27807b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "raw_datasets = load_dataset('json', data_files='https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/boolq/train.json')\n",
    "val_datasets  = load_dataset('json', data_files='https://raw.githubusercontent.com/AGI-Edgerunners/LLM-Adapters/refs/heads/main/dataset/boolq/test.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2c0d83c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Log in using your Hugging Face token\n",
    "login(\"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\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0c6e49f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "\n",
    "\n",
    "# Initialize and fit label encoder\n",
    "label_encoder = LabelEncoder()\n",
    "label_encoder.fit(raw_datasets['train']['answer'])\n",
    "# Create the integer labels\n",
    "train_labels = label_encoder.transform(raw_datasets['train']['answer'])\n",
    "\n",
    "# Add a new 'labels' column\n",
    "raw_datasets['train'] = raw_datasets['train'].add_column('labels', train_labels)\n",
    "\n",
    "val_labels = label_encoder.transform(val_datasets['train']['answer'])\n",
    "\n",
    "# Add a new 'labels' column\n",
    "val_datasets['train'] = val_datasets['train'].add_column('labels', val_labels)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "1453092f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "\n",
    "#model_name = \"microsoft/deberta-v3-base\"\n",
    "model_name = \"answerdotai/ModernBERT-base\"\n",
    "\n",
    "#config.num_labels=2\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.padding_side = 'left'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8d0bb7c1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\\n#from roberta import RobertaForSequenceClassification\\nfrom modeling import CLMSequenceClassification\\n\\n\\n#model_name = \"openai-community/gpt2-medium\"\\nmodel_name = \"HuggingFaceTB/SmolLM2-135M\"\\n#config.num_labels=2\\ntokenizer = AutoTokenizer.from_pretrained(model_name)\\n\\ntokenizer.padding_side = \\'left\\'\\ntokenizer.pad_token = tokenizer.eos_token\\n\\nimport torch\\nimport torch.nn as nn\\nfrom transformers import AutoModelForSequenceClassification\\nfrom transformers.activations import ACT2FN\\nimport random\\n\\n\\n\\nmodel = CLMSequenceClassification.from_pretrained(model_name, num_labels=4).to(\\'cuda\\')\\nmodel.config.pad_token_id = tokenizer.eos_token_id\\nimport RoCoFT\\n\\nRoCoFT.PEFT(model, method=\\'row\\', rank=3) '"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "'''from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "from modeling import CLMSequenceClassification\n",
    "\n",
    "\n",
    "#model_name = \"openai-community/gpt2-medium\"\n",
    "model_name = \"HuggingFaceTB/SmolLM2-135M\"\n",
    "#config.num_labels=2\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "model = CLMSequenceClassification.from_pretrained(model_name, num_labels=4).to('cuda')\n",
    "model.config.pad_token_id = tokenizer.eos_token_id\n",
    "import RoCoFT\n",
    "\n",
    "RoCoFT.PEFT(model, method='row', rank=3) '''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "12f358b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 9427/9427 [00:00<00:00, 17701.87 examples/s]\n",
      "Map: 100%|██████████| 3270/3270 [00:00<00:00, 19237.42 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['instruction', 'input', 'output', 'answer', 'labels'],\n",
      "    num_rows: 9427\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['instruction', 'input', 'output', 'answer', 'labels'],\n",
      "    num_rows: 3270\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    # sorry about the formatting disaster gotta move fast\n",
    "    return f\"\"\"# input: {data_point[\"instruction\"].split('format:')[0]}:\"\"\"\n",
    "               \n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "\n",
    "    example['labels'] = example['labels']\n",
    "  \n",
    "    example['input'] = generate_prompt(example)\n",
    "\n",
    "    \n",
    "    return example\n",
    "\n",
    "# Map the function over train and validation datasets\n",
    "\n",
    "train_data = raw_datasets['train'].map(add_label_column)\n",
    "val_data = val_datasets['train'].map(add_label_column)\n",
    "\n",
    "# Remove unnecessary columns\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ebf4f744",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# input: Please answer the following question with true or false, question: do iran and afghanistan speak the same language?\\n\\nAnswer :'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data['input'][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3544cae9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map:  74%|███████▍  | 7000/9427 [00:00<00:00, 27619.22 examples/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 9427/9427 [00:00<00:00, 27867.19 examples/s]\n",
      "Map: 100%|██████████| 3270/3270 [00:00<00:00, 31449.13 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# col_to_delete = ['idx']\n",
    "col_to_delete =  ['instruction', 'input', 'output', 'answer']\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "   \n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "# llama_tokenized_datasets = llama_tokenized_datasets.rename_column(\"target\", \"label\")\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a70b43e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS]# input: Please answer the following question with true or false, question: is a wolverine the same as a badger?\\n\\nAnswer :[SEP]'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2eea08f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "Total labels (classes): 2\n"
     ]
    }
   ],
   "source": [
    "count = sum(len(ids) > 512 for ids in tokenized_train_data['input_ids'])\n",
    "print(count)\n",
    "num_labels = len(label_encoder.classes_)\n",
    "print(f\"Total labels (classes): {num_labels}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d64fb358",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "# from modeling import MLMSequenceClassification\n",
    "\n",
    "config = AutoConfig.from_pretrained(model_name)\n",
    "config.num_labels = num_labels\n",
    "config.mask_token_id = tokenizer.mask_token_id\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f39dcdc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import evaluate\n",
    "import numpy as np\n",
    "from sklearn import metrics\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "\n",
    "\n",
    "    logits, labels = eval_pred # eval_pred is the tuple of predictions and labels returned by the model\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    \n",
    "    precision = metrics.precision_score(labels, predictions, average=\"macro\")\n",
    "    recall = metrics.recall_score(labels, predictions, average=\"macro\")\n",
    "    f1 = metrics.f1_score(labels, predictions, average=\"macro\")\n",
    "    accuracy = metrics.accuracy_score(labels, predictions)\n",
    "    \n",
    "    return {\"precision\": precision, \"recall\": recall, \"f1-score\": f1, 'accuracy': accuracy}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "b64eb4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=3e-4,\n",
    "    per_device_train_batch_size=128,\n",
    "    per_device_eval_batch_size=128,\n",
    "    num_train_epochs=100,\n",
    "    weight_decay=0.0,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,\n",
    "    logging_steps=100,\n",
    " \n",
    "\n",
    "   \n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5b333893",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/torch/_inductor/compile_fx.py:194: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='7400' max='7400' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [7400/7400 41:27, Epoch 100/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1-score</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.680000</td>\n",
       "      <td>0.689846</td>\n",
       "      <td>0.496045</td>\n",
       "      <td>0.498491</td>\n",
       "      <td>0.444705</td>\n",
       "      <td>0.596024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.691000</td>\n",
       "      <td>0.664213</td>\n",
       "      <td>0.310856</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.383368</td>\n",
       "      <td>0.621713</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.680600</td>\n",
       "      <td>0.694311</td>\n",
       "      <td>0.518455</td>\n",
       "      <td>0.515761</td>\n",
       "      <td>0.455410</td>\n",
       "      <td>0.460856</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.637700</td>\n",
       "      <td>0.708441</td>\n",
       "      <td>0.562071</td>\n",
       "      <td>0.562323</td>\n",
       "      <td>0.529962</td>\n",
       "      <td>0.529969</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.537000</td>\n",
       "      <td>0.765357</td>\n",
       "      <td>0.593536</td>\n",
       "      <td>0.569768</td>\n",
       "      <td>0.564180</td>\n",
       "      <td>0.632110</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.385700</td>\n",
       "      <td>1.171266</td>\n",
       "      <td>0.600996</td>\n",
       "      <td>0.600441</td>\n",
       "      <td>0.600700</td>\n",
       "      <td>0.625382</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.212500</td>\n",
       "      <td>1.320062</td>\n",
       "      <td>0.606793</td>\n",
       "      <td>0.598937</td>\n",
       "      <td>0.600524</td>\n",
       "      <td>0.636697</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.129000</td>\n",
       "      <td>1.477925</td>\n",
       "      <td>0.603342</td>\n",
       "      <td>0.601493</td>\n",
       "      <td>0.602239</td>\n",
       "      <td>0.629052</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.064100</td>\n",
       "      <td>2.359927</td>\n",
       "      <td>0.607790</td>\n",
       "      <td>0.603950</td>\n",
       "      <td>0.605223</td>\n",
       "      <td>0.634862</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.038400</td>\n",
       "      <td>2.185456</td>\n",
       "      <td>0.604927</td>\n",
       "      <td>0.608426</td>\n",
       "      <td>0.605722</td>\n",
       "      <td>0.622324</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.028400</td>\n",
       "      <td>1.922115</td>\n",
       "      <td>0.608083</td>\n",
       "      <td>0.601557</td>\n",
       "      <td>0.603155</td>\n",
       "      <td>0.637003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.019100</td>\n",
       "      <td>2.645177</td>\n",
       "      <td>0.617303</td>\n",
       "      <td>0.621113</td>\n",
       "      <td>0.618292</td>\n",
       "      <td>0.634557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.011800</td>\n",
       "      <td>2.479312</td>\n",
       "      <td>0.610709</td>\n",
       "      <td>0.607412</td>\n",
       "      <td>0.608592</td>\n",
       "      <td>0.637003</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.012200</td>\n",
       "      <td>2.874300</td>\n",
       "      <td>0.615971</td>\n",
       "      <td>0.618069</td>\n",
       "      <td>0.616777</td>\n",
       "      <td>0.636086</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.013700</td>\n",
       "      <td>2.737432</td>\n",
       "      <td>0.600169</td>\n",
       "      <td>0.605905</td>\n",
       "      <td>0.598487</td>\n",
       "      <td>0.608563</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.008300</td>\n",
       "      <td>3.710084</td>\n",
       "      <td>0.608649</td>\n",
       "      <td>0.611500</td>\n",
       "      <td>0.609527</td>\n",
       "      <td>0.627523</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.014100</td>\n",
       "      <td>2.429986</td>\n",
       "      <td>0.604040</td>\n",
       "      <td>0.606033</td>\n",
       "      <td>0.604768</td>\n",
       "      <td>0.624465</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.006200</td>\n",
       "      <td>2.543256</td>\n",
       "      <td>0.608537</td>\n",
       "      <td>0.610056</td>\n",
       "      <td>0.609158</td>\n",
       "      <td>0.629664</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.004500</td>\n",
       "      <td>2.799902</td>\n",
       "      <td>0.606877</td>\n",
       "      <td>0.600186</td>\n",
       "      <td>0.601762</td>\n",
       "      <td>0.636086</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.004100</td>\n",
       "      <td>3.231751</td>\n",
       "      <td>0.600998</td>\n",
       "      <td>0.603170</td>\n",
       "      <td>0.601739</td>\n",
       "      <td>0.621101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.003000</td>\n",
       "      <td>3.184333</td>\n",
       "      <td>0.608555</td>\n",
       "      <td>0.610285</td>\n",
       "      <td>0.609238</td>\n",
       "      <td>0.629358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.005900</td>\n",
       "      <td>2.406815</td>\n",
       "      <td>0.612717</td>\n",
       "      <td>0.618095</td>\n",
       "      <td>0.612915</td>\n",
       "      <td>0.625688</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.005200</td>\n",
       "      <td>2.415284</td>\n",
       "      <td>0.608557</td>\n",
       "      <td>0.607996</td>\n",
       "      <td>0.608261</td>\n",
       "      <td>0.632416</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.009700</td>\n",
       "      <td>2.373675</td>\n",
       "      <td>0.606211</td>\n",
       "      <td>0.605871</td>\n",
       "      <td>0.606034</td>\n",
       "      <td>0.629969</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.005100</td>\n",
       "      <td>2.938803</td>\n",
       "      <td>0.607885</td>\n",
       "      <td>0.612982</td>\n",
       "      <td>0.608036</td>\n",
       "      <td>0.621101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.002800</td>\n",
       "      <td>3.602328</td>\n",
       "      <td>0.609709</td>\n",
       "      <td>0.608135</td>\n",
       "      <td>0.608804</td>\n",
       "      <td>0.634557</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.003700</td>\n",
       "      <td>2.854026</td>\n",
       "      <td>0.603647</td>\n",
       "      <td>0.605928</td>\n",
       "      <td>0.604423</td>\n",
       "      <td>0.623547</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.005300</td>\n",
       "      <td>2.750162</td>\n",
       "      <td>0.606285</td>\n",
       "      <td>0.603723</td>\n",
       "      <td>0.604691</td>\n",
       "      <td>0.632416</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.003100</td>\n",
       "      <td>3.578363</td>\n",
       "      <td>0.604221</td>\n",
       "      <td>0.605593</td>\n",
       "      <td>0.604787</td>\n",
       "      <td>0.625688</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.004000</td>\n",
       "      <td>2.928020</td>\n",
       "      <td>0.597598</td>\n",
       "      <td>0.600571</td>\n",
       "      <td>0.598327</td>\n",
       "      <td>0.615902</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.001500</td>\n",
       "      <td>3.652220</td>\n",
       "      <td>0.604633</td>\n",
       "      <td>0.603022</td>\n",
       "      <td>0.603693</td>\n",
       "      <td>0.629969</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>4.131946</td>\n",
       "      <td>0.601085</td>\n",
       "      <td>0.602271</td>\n",
       "      <td>0.601584</td>\n",
       "      <td>0.622936</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.001000</td>\n",
       "      <td>3.794489</td>\n",
       "      <td>0.601416</td>\n",
       "      <td>0.603803</td>\n",
       "      <td>0.602188</td>\n",
       "      <td>0.621101</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.000100</td>\n",
       "      <td>4.118119</td>\n",
       "      <td>0.599100</td>\n",
       "      <td>0.600920</td>\n",
       "      <td>0.599767</td>\n",
       "      <td>0.619878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.245060</td>\n",
       "      <td>0.599100</td>\n",
       "      <td>0.600920</td>\n",
       "      <td>0.599767</td>\n",
       "      <td>0.619878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.340887</td>\n",
       "      <td>0.598729</td>\n",
       "      <td>0.600516</td>\n",
       "      <td>0.599387</td>\n",
       "      <td>0.619572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.419894</td>\n",
       "      <td>0.599264</td>\n",
       "      <td>0.601007</td>\n",
       "      <td>0.599915</td>\n",
       "      <td>0.620183</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.486254</td>\n",
       "      <td>0.598893</td>\n",
       "      <td>0.600603</td>\n",
       "      <td>0.599535</td>\n",
       "      <td>0.619878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.547012</td>\n",
       "      <td>0.598522</td>\n",
       "      <td>0.600199</td>\n",
       "      <td>0.599155</td>\n",
       "      <td>0.619572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.598845</td>\n",
       "      <td>0.598522</td>\n",
       "      <td>0.600199</td>\n",
       "      <td>0.599155</td>\n",
       "      <td>0.619572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.645165</td>\n",
       "      <td>0.598522</td>\n",
       "      <td>0.600199</td>\n",
       "      <td>0.599155</td>\n",
       "      <td>0.619572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.687711</td>\n",
       "      <td>0.598522</td>\n",
       "      <td>0.600199</td>\n",
       "      <td>0.599155</td>\n",
       "      <td>0.619572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.727278</td>\n",
       "      <td>0.598419</td>\n",
       "      <td>0.600041</td>\n",
       "      <td>0.599038</td>\n",
       "      <td>0.619572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.762824</td>\n",
       "      <td>0.598419</td>\n",
       "      <td>0.600041</td>\n",
       "      <td>0.599038</td>\n",
       "      <td>0.619572</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.795006</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.825113</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.853432</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.879235</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.903010</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.924873</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.945240</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.964345</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.981616</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.998157</td>\n",
       "      <td>0.597676</td>\n",
       "      <td>0.599232</td>\n",
       "      <td>0.598276</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.013163</td>\n",
       "      <td>0.596933</td>\n",
       "      <td>0.598424</td>\n",
       "      <td>0.597513</td>\n",
       "      <td>0.618349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.026835</td>\n",
       "      <td>0.596933</td>\n",
       "      <td>0.598424</td>\n",
       "      <td>0.597513</td>\n",
       "      <td>0.618349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.039532</td>\n",
       "      <td>0.597201</td>\n",
       "      <td>0.598670</td>\n",
       "      <td>0.597777</td>\n",
       "      <td>0.618654</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.050751</td>\n",
       "      <td>0.597201</td>\n",
       "      <td>0.598670</td>\n",
       "      <td>0.597777</td>\n",
       "      <td>0.618654</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.061120</td>\n",
       "      <td>0.597201</td>\n",
       "      <td>0.598670</td>\n",
       "      <td>0.597777</td>\n",
       "      <td>0.618654</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.070796</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6100</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.079347</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6200</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.086469</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6300</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.093033</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6400</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.098839</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6500</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.103270</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6600</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.107119</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6700</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.109893</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6800</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.112200</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6900</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.113716</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.114604</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7100</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.115117</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7200</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.115268</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7300</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.115300</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7400</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>5.115300</td>\n",
       "      <td>0.597470</td>\n",
       "      <td>0.598916</td>\n",
       "      <td>0.598040</td>\n",
       "      <td>0.618960</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, f\"{metric.capitalize()} is\", len(result))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=7400, training_loss=0.05715395251987158, metrics={'train_runtime': 2496.2201, 'train_samples_per_second': 377.651, 'train_steps_per_second': 2.964, 'total_flos': 2.4757440510726504e+16, 'train_loss': 0.05715395251987158, 'epoch': 100.0})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5363c025",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
