{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e27807b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
    "\n",
    "import numpy as np\n",
    "import requests\n",
    "import pandas as pd\n",
    "from io import StringIO\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "#load train data\n",
    "import pandas as pd\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "raw_datasets  = load_dataset('RobZamp/sick')\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c0d83c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Log in using your Hugging Face token\n",
    "login(\"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\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8d0bb7c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at HuggingFaceTB/SmolLM2-360M and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "# from modeling import CLMSequenceClassification\n",
    "\n",
    "\n",
    "#model_name = \"openai-community/gpt2-medium\"\n",
    "model_name = \"HuggingFaceTB/SmolLM2-360M\"\n",
    "#config.num_labels=2\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "#tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1).to('cuda')\n",
    "\n",
    "model.config.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "import RoCoFT\n",
    "\n",
    "RoCoFT.PEFT(model, method='row', rank=3) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "12f358b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
      "    num_rows: 4439\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
      "    num_rows: 495\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Generates a prompt for evaluating the humor intensity of an edited headline.\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'original', 'edit', and 'meanGrade'.\n",
    "    Returns:\n",
    "        str: The formatted prompt as a string.\n",
    "    \"\"\"\n",
    "    return f\"\"\"# Sentence-1:: {data_point['sentence_A']}. # Sentence-2: {data_point['sentence_B']} # Output: The similarity is\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "\n",
    "    example['labels'] = float(example['relatedness_score'])\n",
    "  \n",
    "    example['input'] = generate_prompt(example)\n",
    "\n",
    "    \n",
    "    return example\n",
    "\n",
    "# Map the function over train and validation datasets\n",
    "\n",
    "train_data = raw_datasets['train'].map(add_label_column)\n",
    "val_data = raw_datasets['validation'].map(add_label_column)\n",
    "\n",
    "# Remove unnecessary columns\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3544cae9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 4439/4439 [00:01<00:00, 3384.00 examples/s]\n",
      "Map: 100%|██████████| 495/495 [00:00<00:00, 3168.21 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# col_to_delete = ['idx']\n",
    "col_to_delete =  ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'input']\n",
    "\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "   \n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "# llama_tokenized_datasets = llama_tokenized_datasets.rename_column(\"target\", \"label\")\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a70b43e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# Sentence-1:: A person on a black motorbike is doing tricks with a jacket. # Sentence-2: A person is riding the bicycle on one wheel # Output: The similarity is'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f39dcdc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # If predictions are logits or have extra dimensions, squeeze\n",
    "    if predictions.ndim > 1:\n",
    "        predictions = predictions.squeeze()\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "    \n",
    "    # Define an \"accuracy\" for regression:\n",
    "    # Example: within some threshold tolerance\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b64eb4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=1e-3,\n",
    "    per_device_train_batch_size=14,\n",
    "    per_device_eval_batch_size=64,\n",
    "    gradient_accumulation_steps= 1,\n",
    "    num_train_epochs=20,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,\n",
    "    logging_steps=100,\n",
    "   \n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5b333893",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6360' max='6360' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6360/6360 37:26, Epoch 20/20]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>4.345600</td>\n",
       "      <td>0.399384</td>\n",
       "      <td>0.491662</td>\n",
       "      <td>0.399384</td>\n",
       "      <td>0.631968</td>\n",
       "      <td>0.129293</td>\n",
       "      <td>0.605969</td>\n",
       "      <td>0.832601</td>\n",
       "      <td>0.796561</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.428300</td>\n",
       "      <td>0.516170</td>\n",
       "      <td>0.566812</td>\n",
       "      <td>0.516170</td>\n",
       "      <td>0.718450</td>\n",
       "      <td>0.129293</td>\n",
       "      <td>0.490748</td>\n",
       "      <td>0.856817</td>\n",
       "      <td>0.811915</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.305200</td>\n",
       "      <td>0.242989</td>\n",
       "      <td>0.391109</td>\n",
       "      <td>0.242989</td>\n",
       "      <td>0.492939</td>\n",
       "      <td>0.163636</td>\n",
       "      <td>0.760268</td>\n",
       "      <td>0.874685</td>\n",
       "      <td>0.830050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.252000</td>\n",
       "      <td>0.328626</td>\n",
       "      <td>0.459997</td>\n",
       "      <td>0.328626</td>\n",
       "      <td>0.573259</td>\n",
       "      <td>0.135354</td>\n",
       "      <td>0.675778</td>\n",
       "      <td>0.865488</td>\n",
       "      <td>0.814743</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.273900</td>\n",
       "      <td>0.208186</td>\n",
       "      <td>0.370459</td>\n",
       "      <td>0.208186</td>\n",
       "      <td>0.456274</td>\n",
       "      <td>0.145455</td>\n",
       "      <td>0.794604</td>\n",
       "      <td>0.897377</td>\n",
       "      <td>0.849892</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.213900</td>\n",
       "      <td>0.233300</td>\n",
       "      <td>0.381591</td>\n",
       "      <td>0.233300</td>\n",
       "      <td>0.483011</td>\n",
       "      <td>0.171717</td>\n",
       "      <td>0.769827</td>\n",
       "      <td>0.899399</td>\n",
       "      <td>0.865580</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.197300</td>\n",
       "      <td>0.326512</td>\n",
       "      <td>0.465491</td>\n",
       "      <td>0.326512</td>\n",
       "      <td>0.571412</td>\n",
       "      <td>0.143434</td>\n",
       "      <td>0.677864</td>\n",
       "      <td>0.882677</td>\n",
       "      <td>0.839024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.169300</td>\n",
       "      <td>0.214368</td>\n",
       "      <td>0.360209</td>\n",
       "      <td>0.214368</td>\n",
       "      <td>0.462999</td>\n",
       "      <td>0.185859</td>\n",
       "      <td>0.788505</td>\n",
       "      <td>0.889272</td>\n",
       "      <td>0.845726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.173000</td>\n",
       "      <td>0.220384</td>\n",
       "      <td>0.362961</td>\n",
       "      <td>0.220384</td>\n",
       "      <td>0.469451</td>\n",
       "      <td>0.204040</td>\n",
       "      <td>0.782570</td>\n",
       "      <td>0.897686</td>\n",
       "      <td>0.859824</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.164300</td>\n",
       "      <td>0.233805</td>\n",
       "      <td>0.382315</td>\n",
       "      <td>0.233805</td>\n",
       "      <td>0.483534</td>\n",
       "      <td>0.183838</td>\n",
       "      <td>0.769328</td>\n",
       "      <td>0.899740</td>\n",
       "      <td>0.864695</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.129500</td>\n",
       "      <td>0.202935</td>\n",
       "      <td>0.357227</td>\n",
       "      <td>0.202935</td>\n",
       "      <td>0.450483</td>\n",
       "      <td>0.169697</td>\n",
       "      <td>0.799785</td>\n",
       "      <td>0.896192</td>\n",
       "      <td>0.860704</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.155500</td>\n",
       "      <td>0.206772</td>\n",
       "      <td>0.351195</td>\n",
       "      <td>0.206772</td>\n",
       "      <td>0.454722</td>\n",
       "      <td>0.189899</td>\n",
       "      <td>0.795999</td>\n",
       "      <td>0.903053</td>\n",
       "      <td>0.872519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.161100</td>\n",
       "      <td>0.203963</td>\n",
       "      <td>0.350732</td>\n",
       "      <td>0.203963</td>\n",
       "      <td>0.451622</td>\n",
       "      <td>0.191919</td>\n",
       "      <td>0.798771</td>\n",
       "      <td>0.896345</td>\n",
       "      <td>0.861521</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.118900</td>\n",
       "      <td>0.233764</td>\n",
       "      <td>0.385907</td>\n",
       "      <td>0.233764</td>\n",
       "      <td>0.483492</td>\n",
       "      <td>0.161616</td>\n",
       "      <td>0.769369</td>\n",
       "      <td>0.890002</td>\n",
       "      <td>0.856172</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.119300</td>\n",
       "      <td>0.226199</td>\n",
       "      <td>0.369791</td>\n",
       "      <td>0.226199</td>\n",
       "      <td>0.475603</td>\n",
       "      <td>0.177778</td>\n",
       "      <td>0.776833</td>\n",
       "      <td>0.886840</td>\n",
       "      <td>0.839032</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.129300</td>\n",
       "      <td>0.207850</td>\n",
       "      <td>0.350489</td>\n",
       "      <td>0.207850</td>\n",
       "      <td>0.455906</td>\n",
       "      <td>0.210101</td>\n",
       "      <td>0.794936</td>\n",
       "      <td>0.898237</td>\n",
       "      <td>0.866125</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.104900</td>\n",
       "      <td>0.205811</td>\n",
       "      <td>0.355869</td>\n",
       "      <td>0.205811</td>\n",
       "      <td>0.453664</td>\n",
       "      <td>0.187879</td>\n",
       "      <td>0.796947</td>\n",
       "      <td>0.893798</td>\n",
       "      <td>0.855642</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.112500</td>\n",
       "      <td>0.227638</td>\n",
       "      <td>0.360461</td>\n",
       "      <td>0.227638</td>\n",
       "      <td>0.477114</td>\n",
       "      <td>0.206061</td>\n",
       "      <td>0.775413</td>\n",
       "      <td>0.896072</td>\n",
       "      <td>0.859641</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.111800</td>\n",
       "      <td>0.339199</td>\n",
       "      <td>0.460484</td>\n",
       "      <td>0.339199</td>\n",
       "      <td>0.582408</td>\n",
       "      <td>0.133333</td>\n",
       "      <td>0.665347</td>\n",
       "      <td>0.894735</td>\n",
       "      <td>0.862885</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.106600</td>\n",
       "      <td>0.209515</td>\n",
       "      <td>0.353662</td>\n",
       "      <td>0.209515</td>\n",
       "      <td>0.457728</td>\n",
       "      <td>0.193939</td>\n",
       "      <td>0.793293</td>\n",
       "      <td>0.893967</td>\n",
       "      <td>0.857718</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.085200</td>\n",
       "      <td>0.207520</td>\n",
       "      <td>0.344691</td>\n",
       "      <td>0.207520</td>\n",
       "      <td>0.455544</td>\n",
       "      <td>0.218182</td>\n",
       "      <td>0.795261</td>\n",
       "      <td>0.893575</td>\n",
       "      <td>0.857957</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.099800</td>\n",
       "      <td>0.229150</td>\n",
       "      <td>0.365150</td>\n",
       "      <td>0.229150</td>\n",
       "      <td>0.478696</td>\n",
       "      <td>0.187879</td>\n",
       "      <td>0.773921</td>\n",
       "      <td>0.883309</td>\n",
       "      <td>0.838772</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.072900</td>\n",
       "      <td>0.250153</td>\n",
       "      <td>0.389660</td>\n",
       "      <td>0.250152</td>\n",
       "      <td>0.500152</td>\n",
       "      <td>0.155556</td>\n",
       "      <td>0.753200</td>\n",
       "      <td>0.885017</td>\n",
       "      <td>0.849559</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.080000</td>\n",
       "      <td>0.214413</td>\n",
       "      <td>0.356416</td>\n",
       "      <td>0.214413</td>\n",
       "      <td>0.463047</td>\n",
       "      <td>0.187879</td>\n",
       "      <td>0.788461</td>\n",
       "      <td>0.888330</td>\n",
       "      <td>0.846082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.078600</td>\n",
       "      <td>0.211045</td>\n",
       "      <td>0.353893</td>\n",
       "      <td>0.211045</td>\n",
       "      <td>0.459397</td>\n",
       "      <td>0.185859</td>\n",
       "      <td>0.791783</td>\n",
       "      <td>0.889885</td>\n",
       "      <td>0.847083</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.066400</td>\n",
       "      <td>0.229040</td>\n",
       "      <td>0.366346</td>\n",
       "      <td>0.229040</td>\n",
       "      <td>0.478581</td>\n",
       "      <td>0.197980</td>\n",
       "      <td>0.774030</td>\n",
       "      <td>0.890048</td>\n",
       "      <td>0.852062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.062800</td>\n",
       "      <td>0.215201</td>\n",
       "      <td>0.356552</td>\n",
       "      <td>0.215201</td>\n",
       "      <td>0.463897</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>0.787683</td>\n",
       "      <td>0.888148</td>\n",
       "      <td>0.845666</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.061200</td>\n",
       "      <td>0.203224</td>\n",
       "      <td>0.354811</td>\n",
       "      <td>0.203224</td>\n",
       "      <td>0.450804</td>\n",
       "      <td>0.175758</td>\n",
       "      <td>0.799499</td>\n",
       "      <td>0.898288</td>\n",
       "      <td>0.860214</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.067000</td>\n",
       "      <td>0.192453</td>\n",
       "      <td>0.336749</td>\n",
       "      <td>0.192453</td>\n",
       "      <td>0.438695</td>\n",
       "      <td>0.224242</td>\n",
       "      <td>0.810126</td>\n",
       "      <td>0.901701</td>\n",
       "      <td>0.869677</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.046600</td>\n",
       "      <td>0.210871</td>\n",
       "      <td>0.356173</td>\n",
       "      <td>0.210871</td>\n",
       "      <td>0.459207</td>\n",
       "      <td>0.202020</td>\n",
       "      <td>0.791955</td>\n",
       "      <td>0.891134</td>\n",
       "      <td>0.849401</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.051900</td>\n",
       "      <td>0.200420</td>\n",
       "      <td>0.343498</td>\n",
       "      <td>0.200420</td>\n",
       "      <td>0.447683</td>\n",
       "      <td>0.214141</td>\n",
       "      <td>0.802266</td>\n",
       "      <td>0.895744</td>\n",
       "      <td>0.853178</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.055000</td>\n",
       "      <td>0.213805</td>\n",
       "      <td>0.362097</td>\n",
       "      <td>0.213805</td>\n",
       "      <td>0.462391</td>\n",
       "      <td>0.171717</td>\n",
       "      <td>0.789060</td>\n",
       "      <td>0.891553</td>\n",
       "      <td>0.847308</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.037800</td>\n",
       "      <td>0.202941</td>\n",
       "      <td>0.343348</td>\n",
       "      <td>0.202941</td>\n",
       "      <td>0.450490</td>\n",
       "      <td>0.210101</td>\n",
       "      <td>0.799779</td>\n",
       "      <td>0.895330</td>\n",
       "      <td>0.852442</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.041100</td>\n",
       "      <td>0.210157</td>\n",
       "      <td>0.343460</td>\n",
       "      <td>0.210157</td>\n",
       "      <td>0.458429</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.792660</td>\n",
       "      <td>0.893764</td>\n",
       "      <td>0.853017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.042300</td>\n",
       "      <td>0.232446</td>\n",
       "      <td>0.372060</td>\n",
       "      <td>0.232446</td>\n",
       "      <td>0.482127</td>\n",
       "      <td>0.175758</td>\n",
       "      <td>0.770669</td>\n",
       "      <td>0.886417</td>\n",
       "      <td>0.841122</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.030400</td>\n",
       "      <td>0.203250</td>\n",
       "      <td>0.341710</td>\n",
       "      <td>0.203250</td>\n",
       "      <td>0.450833</td>\n",
       "      <td>0.208081</td>\n",
       "      <td>0.799474</td>\n",
       "      <td>0.894552</td>\n",
       "      <td>0.854211</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.028000</td>\n",
       "      <td>0.207064</td>\n",
       "      <td>0.343697</td>\n",
       "      <td>0.207064</td>\n",
       "      <td>0.455043</td>\n",
       "      <td>0.212121</td>\n",
       "      <td>0.795711</td>\n",
       "      <td>0.897275</td>\n",
       "      <td>0.858696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.032900</td>\n",
       "      <td>0.201608</td>\n",
       "      <td>0.343083</td>\n",
       "      <td>0.201608</td>\n",
       "      <td>0.449007</td>\n",
       "      <td>0.212121</td>\n",
       "      <td>0.801094</td>\n",
       "      <td>0.896660</td>\n",
       "      <td>0.858022</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.021800</td>\n",
       "      <td>0.206363</td>\n",
       "      <td>0.344341</td>\n",
       "      <td>0.206363</td>\n",
       "      <td>0.454272</td>\n",
       "      <td>0.200000</td>\n",
       "      <td>0.796403</td>\n",
       "      <td>0.894257</td>\n",
       "      <td>0.854017</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.025100</td>\n",
       "      <td>0.196349</td>\n",
       "      <td>0.337561</td>\n",
       "      <td>0.196349</td>\n",
       "      <td>0.443113</td>\n",
       "      <td>0.210101</td>\n",
       "      <td>0.806282</td>\n",
       "      <td>0.898274</td>\n",
       "      <td>0.859427</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.023300</td>\n",
       "      <td>0.199257</td>\n",
       "      <td>0.341857</td>\n",
       "      <td>0.199257</td>\n",
       "      <td>0.446382</td>\n",
       "      <td>0.224242</td>\n",
       "      <td>0.803413</td>\n",
       "      <td>0.897039</td>\n",
       "      <td>0.857417</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.017600</td>\n",
       "      <td>0.198300</td>\n",
       "      <td>0.340919</td>\n",
       "      <td>0.198300</td>\n",
       "      <td>0.445309</td>\n",
       "      <td>0.220202</td>\n",
       "      <td>0.804358</td>\n",
       "      <td>0.896994</td>\n",
       "      <td>0.858422</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.014100</td>\n",
       "      <td>0.200539</td>\n",
       "      <td>0.340634</td>\n",
       "      <td>0.200539</td>\n",
       "      <td>0.447816</td>\n",
       "      <td>0.214141</td>\n",
       "      <td>0.802148</td>\n",
       "      <td>0.896564</td>\n",
       "      <td>0.857918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.015000</td>\n",
       "      <td>0.204993</td>\n",
       "      <td>0.350676</td>\n",
       "      <td>0.204993</td>\n",
       "      <td>0.452762</td>\n",
       "      <td>0.191919</td>\n",
       "      <td>0.797754</td>\n",
       "      <td>0.893613</td>\n",
       "      <td>0.853390</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.014800</td>\n",
       "      <td>0.202123</td>\n",
       "      <td>0.341807</td>\n",
       "      <td>0.202123</td>\n",
       "      <td>0.449581</td>\n",
       "      <td>0.236364</td>\n",
       "      <td>0.800585</td>\n",
       "      <td>0.895172</td>\n",
       "      <td>0.856522</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.011400</td>\n",
       "      <td>0.201536</td>\n",
       "      <td>0.342159</td>\n",
       "      <td>0.201536</td>\n",
       "      <td>0.448927</td>\n",
       "      <td>0.220202</td>\n",
       "      <td>0.801165</td>\n",
       "      <td>0.895578</td>\n",
       "      <td>0.855243</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.011200</td>\n",
       "      <td>0.201083</td>\n",
       "      <td>0.342333</td>\n",
       "      <td>0.201083</td>\n",
       "      <td>0.448423</td>\n",
       "      <td>0.226263</td>\n",
       "      <td>0.801612</td>\n",
       "      <td>0.895890</td>\n",
       "      <td>0.857140</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.008800</td>\n",
       "      <td>0.203350</td>\n",
       "      <td>0.341922</td>\n",
       "      <td>0.203350</td>\n",
       "      <td>0.450943</td>\n",
       "      <td>0.236364</td>\n",
       "      <td>0.799375</td>\n",
       "      <td>0.894881</td>\n",
       "      <td>0.854202</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.007000</td>\n",
       "      <td>0.200774</td>\n",
       "      <td>0.340907</td>\n",
       "      <td>0.200775</td>\n",
       "      <td>0.448079</td>\n",
       "      <td>0.234343</td>\n",
       "      <td>0.801916</td>\n",
       "      <td>0.896068</td>\n",
       "      <td>0.856364</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.007800</td>\n",
       "      <td>0.201663</td>\n",
       "      <td>0.340743</td>\n",
       "      <td>0.201663</td>\n",
       "      <td>0.449070</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.801039</td>\n",
       "      <td>0.896172</td>\n",
       "      <td>0.856493</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.007100</td>\n",
       "      <td>0.202032</td>\n",
       "      <td>0.342792</td>\n",
       "      <td>0.202032</td>\n",
       "      <td>0.449479</td>\n",
       "      <td>0.218182</td>\n",
       "      <td>0.800676</td>\n",
       "      <td>0.895282</td>\n",
       "      <td>0.855472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.005700</td>\n",
       "      <td>0.202352</td>\n",
       "      <td>0.342853</td>\n",
       "      <td>0.202352</td>\n",
       "      <td>0.449836</td>\n",
       "      <td>0.230303</td>\n",
       "      <td>0.800360</td>\n",
       "      <td>0.895020</td>\n",
       "      <td>0.854820</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.005300</td>\n",
       "      <td>0.203299</td>\n",
       "      <td>0.343829</td>\n",
       "      <td>0.203299</td>\n",
       "      <td>0.450886</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.799426</td>\n",
       "      <td>0.894544</td>\n",
       "      <td>0.854004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.005200</td>\n",
       "      <td>0.203492</td>\n",
       "      <td>0.344027</td>\n",
       "      <td>0.203492</td>\n",
       "      <td>0.451101</td>\n",
       "      <td>0.236364</td>\n",
       "      <td>0.799235</td>\n",
       "      <td>0.894820</td>\n",
       "      <td>0.854718</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.003900</td>\n",
       "      <td>0.201927</td>\n",
       "      <td>0.342532</td>\n",
       "      <td>0.201927</td>\n",
       "      <td>0.449363</td>\n",
       "      <td>0.230303</td>\n",
       "      <td>0.800779</td>\n",
       "      <td>0.895307</td>\n",
       "      <td>0.855470</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.004700</td>\n",
       "      <td>0.202234</td>\n",
       "      <td>0.342514</td>\n",
       "      <td>0.202234</td>\n",
       "      <td>0.449704</td>\n",
       "      <td>0.230303</td>\n",
       "      <td>0.800476</td>\n",
       "      <td>0.895111</td>\n",
       "      <td>0.854741</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.004100</td>\n",
       "      <td>0.202706</td>\n",
       "      <td>0.343158</td>\n",
       "      <td>0.202706</td>\n",
       "      <td>0.450229</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.800011</td>\n",
       "      <td>0.895147</td>\n",
       "      <td>0.854549</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.003300</td>\n",
       "      <td>0.202173</td>\n",
       "      <td>0.342812</td>\n",
       "      <td>0.202173</td>\n",
       "      <td>0.449636</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.800537</td>\n",
       "      <td>0.895130</td>\n",
       "      <td>0.854340</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.003900</td>\n",
       "      <td>0.202312</td>\n",
       "      <td>0.342356</td>\n",
       "      <td>0.202312</td>\n",
       "      <td>0.449791</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.800399</td>\n",
       "      <td>0.895238</td>\n",
       "      <td>0.854470</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.003300</td>\n",
       "      <td>0.202087</td>\n",
       "      <td>0.342711</td>\n",
       "      <td>0.202087</td>\n",
       "      <td>0.449541</td>\n",
       "      <td>0.238384</td>\n",
       "      <td>0.800621</td>\n",
       "      <td>0.895171</td>\n",
       "      <td>0.854682</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6100</td>\n",
       "      <td>0.003400</td>\n",
       "      <td>0.202154</td>\n",
       "      <td>0.342510</td>\n",
       "      <td>0.202154</td>\n",
       "      <td>0.449616</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.800555</td>\n",
       "      <td>0.895165</td>\n",
       "      <td>0.854793</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6200</td>\n",
       "      <td>0.002900</td>\n",
       "      <td>0.202202</td>\n",
       "      <td>0.342600</td>\n",
       "      <td>0.202202</td>\n",
       "      <td>0.449668</td>\n",
       "      <td>0.234343</td>\n",
       "      <td>0.800508</td>\n",
       "      <td>0.895167</td>\n",
       "      <td>0.854770</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6300</td>\n",
       "      <td>0.002800</td>\n",
       "      <td>0.202150</td>\n",
       "      <td>0.342519</td>\n",
       "      <td>0.202150</td>\n",
       "      <td>0.449611</td>\n",
       "      <td>0.236364</td>\n",
       "      <td>0.800559</td>\n",
       "      <td>0.895175</td>\n",
       "      <td>0.854707</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=6360, training_loss=0.1422661987631201, metrics={'train_runtime': 2246.888, 'train_samples_per_second': 39.512, 'train_steps_per_second': 2.831, 'total_flos': 23566403370240.0, 'train_loss': 0.1422661987631201, 'epoch': 20.0})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d3bbea",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
