{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e27807b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Downloading data: 100%|██████████| 1.75M/1.75M [00:00<00:00, 121MB/s]\n",
      "Generating train split: 7232 examples [00:00, 144847.34 examples/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['id', 'corpus', 'sentence', 'token', 'complexity'],\n",
      "        num_rows: 7232\n",
      "    })\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading data: 100%|██████████| 212k/212k [00:00<00:00, 84.2MB/s]\n",
      "Generating train split: 887 examples [00:00, 92518.34 examples/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['id', 'corpus', 'sentence', 'token', 'complexity'],\n",
      "        num_rows: 887\n",
      "    })\n",
      "})\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
    "from torch.utils.data import Dataset\n",
    "import logging\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "# URL of the TSV file\n",
    "url = \"https://raw.githubusercontent.com/MMU-TDMLab/CompLex/master/train/lcp_single_train.tsv\"\n",
    "test_url = \"https://raw.githubusercontent.com/MMU-TDMLab/CompLex/refs/heads/master/test-labels/lcp_single_test.tsv\"\n",
    "# Load the TSV file using the csv format\n",
    "train_data = load_dataset(\n",
    "    \"csv\",\n",
    "    data_files=url,\n",
    "    delimiter=\"\\t\"  # Specify tab-separated values\n",
    ")\n",
    "\n",
    "# Inspect the dataset\n",
    "print(train_data)\n",
    "\n",
    "\n",
    "val_data = load_dataset(\n",
    "    \"csv\",\n",
    "    data_files=test_url,\n",
    "    delimiter=\"\\t\"  # Specify tab-separated values\n",
    ")\n",
    "\n",
    "# Inspect the dataset\n",
    "print(val_data)"
   ]
  },
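  {
   "cell_type": "markdown",
   "id": "insp0001",
   "metadata": {},
   "source": [
    "Optional sanity check (not part of the original run): peek at one raw row. Each example pairs a `sentence` with a target `token` and a gold `complexity` score in [0, 1]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "insp0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect a single raw example before any preprocessing\n",
    "print(train_data['train'][0])"
   ]
  },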
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8d0bb7c1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at HuggingFaceTB/SmolLM2-360M and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "#from roberta import RobertaForSequenceClassification\n",
    "# from modeling import CLMSequenceClassification\n",
    "\n",
    "\n",
    "#model_name = \"openai-community/gpt2-medium\"\n",
    "model_name = \"HuggingFaceTB/SmolLM2-360M\"\n",
    "#config.num_labels=2\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1).to('cuda')\n",
    "model.config.pad_token_id = tokenizer.eos_token_id\n",
    "import RoCoFT\n",
    "\n",
    "RoCoFT.PEFT(model, method='row', rank=3) "
   ]
  },
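  {
   "cell_type": "markdown",
   "id": "peft0001",
   "metadata": {},
   "source": [
    "A minimal sketch (not in the original run) to confirm the effect of the PEFT call, assuming RoCoFT freezes the untouched weights by setting `requires_grad=False`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "peft0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count trainable vs. total parameters after RoCoFT.PEFT has run;\n",
    "# assumes the wrapper freezes non-selected rows via requires_grad=False\n",
    "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "total = sum(p.numel() for p in model.parameters())\n",
    "print(f\"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)\")"
   ]
  },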
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "12f358b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 7232/7232 [00:00<00:00, 15408.79 examples/s]\n",
      "Map: 100%|██████████| 887/887 [00:00<00:00, 14765.86 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Generates a prompt for evaluating the humor intensity of an edited headline.\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'original', 'edit', and 'meanGrade'.\n",
    "    Returns:\n",
    "        str: The formatted prompt as a string.\n",
    "    \"\"\"\n",
    "    return f\"\"\"# Sentence: {data_point['sentence']} # Word: {data_point['token']} # Output: The complexity score between word and output is\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "# Assuming `dataset` is your DatasetDict\n",
    "def add_label_column(example):\n",
    "\n",
    "    example['labels'] = float(example['complexity'])\n",
    "  \n",
    "    example['input'] = generate_prompt(example)\n",
    "\n",
    "    \n",
    "    return example\n",
    "\n",
    "train_data = train_data['train'].map(add_label_column)\n",
    "val_data = val_data['train'].map(add_label_column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3544cae9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 7232/7232 [00:00<00:00, 17195.72 examples/s]\n",
      "Map: 100%|██████████| 887/887 [00:00<00:00, 17380.58 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "\n",
    "# col_to_delete = ['idx']\n",
    "col_to_delete = ['id', 'corpus', 'sentence', 'token', 'complexity']\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "def preprocessing_function(examples):\n",
    "   \n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "# llama_tokenized_datasets = llama_tokenized_datasets.rename_column(\"target\", \"label\")\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator for padding a batch of examples to the maximum length seen in the batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
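  {
   "cell_type": "markdown",
   "id": "coll0001",
   "metadata": {},
   "source": [
    "A quick illustrative check (not in the original run) of what the collator produces: `input_ids` left-padded to the longest sequence in the batch, plus the float labels stacked into a vector. The string `input` column is dropped by hand here, mirroring what the `Trainer` does via `remove_unused_columns`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "coll0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collate two examples, keeping only the tensor columns the model consumes\n",
    "features = [\n",
    "    {k: tokenized_train_data[i][k] for k in ('input_ids', 'attention_mask', 'labels')}\n",
    "    for i in range(2)\n",
    "]\n",
    "demo_batch = data_collator(features)\n",
    "print({k: tuple(v.shape) for k, v in demo_batch.items()})"
   ]
  },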
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a70b43e9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# Sentence: Seven days you shall eat unleavened bread, as I commanded you, at the time appointed in the month Abib; for in the month Abib you came out from Egypt. # Word: days # Output: The complexity score between word and output is'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f39dcdc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # If predictions are logits or have extra dimensions, squeeze\n",
    "    if predictions.ndim > 1:\n",
    "        predictions = predictions.squeeze()\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "    \n",
    "    # Define an \"accuracy\" for regression:\n",
    "    # Example: within some threshold tolerance\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
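  {
   "cell_type": "markdown",
   "id": "metr0001",
   "metadata": {},
   "source": [
    "A tiny synthetic check (illustrative only, hand-made numbers) that `compute_metrics` squeezes 2-D predictions and returns the expected dictionary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "metr0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-made predictions/labels; shapes mimic the Trainer's (n, 1) logits\n",
    "_preds = np.array([[0.10], [0.25], [0.40], [0.80]])\n",
    "_gold = np.array([0.12, 0.20, 0.45, 0.70])\n",
    "compute_metrics((_preds, _gold))"
   ]
  },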
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b64eb4e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=5e-4,\n",
    "    per_device_train_batch_size=4,\n",
    "    per_device_eval_batch_size=4,\n",
    "    gradient_accumulation_steps= 4,\n",
    "    num_train_epochs=10,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,\n",
    "    logging_steps=100,\n",
    "   \n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # You can choose from 'linear', 'cosine', 'cosine_with_restarts', 'polynomial', etc.\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
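  {
   "cell_type": "markdown",
   "id": "step0001",
   "metadata": {},
   "source": [
    "Step-count check: with a per-device batch size of 4 and 4 gradient-accumulation steps, the effective batch size is 16, so one epoch over the 7,232 training rows takes 7232 / 16 = 452 optimizer steps, and 10 epochs give the 4,520 total steps shown in the progress bar below."
   ]
  },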
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5b333893",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='4520' max='4520' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [4520/4520 58:56, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.032300</td>\n",
       "      <td>0.021447</td>\n",
       "      <td>0.107797</td>\n",
       "      <td>0.021447</td>\n",
       "      <td>0.146447</td>\n",
       "      <td>0.590755</td>\n",
       "      <td>-0.307908</td>\n",
       "      <td>0.295041</td>\n",
       "      <td>0.328049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.023000</td>\n",
       "      <td>0.031626</td>\n",
       "      <td>0.146031</td>\n",
       "      <td>0.031626</td>\n",
       "      <td>0.177837</td>\n",
       "      <td>0.372041</td>\n",
       "      <td>-0.928672</td>\n",
       "      <td>0.493794</td>\n",
       "      <td>0.503264</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.019200</td>\n",
       "      <td>0.012293</td>\n",
       "      <td>0.084945</td>\n",
       "      <td>0.012293</td>\n",
       "      <td>0.110873</td>\n",
       "      <td>0.670800</td>\n",
       "      <td>0.250331</td>\n",
       "      <td>0.583101</td>\n",
       "      <td>0.577452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.015400</td>\n",
       "      <td>0.019925</td>\n",
       "      <td>0.112476</td>\n",
       "      <td>0.019925</td>\n",
       "      <td>0.141156</td>\n",
       "      <td>0.521984</td>\n",
       "      <td>-0.215117</td>\n",
       "      <td>0.604777</td>\n",
       "      <td>0.605910</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.013500</td>\n",
       "      <td>0.008433</td>\n",
       "      <td>0.070714</td>\n",
       "      <td>0.008433</td>\n",
       "      <td>0.091829</td>\n",
       "      <td>0.738444</td>\n",
       "      <td>0.485747</td>\n",
       "      <td>0.700656</td>\n",
       "      <td>0.671422</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.016300</td>\n",
       "      <td>0.027296</td>\n",
       "      <td>0.146153</td>\n",
       "      <td>0.027296</td>\n",
       "      <td>0.165215</td>\n",
       "      <td>0.284104</td>\n",
       "      <td>-0.664613</td>\n",
       "      <td>0.698331</td>\n",
       "      <td>0.690801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.013900</td>\n",
       "      <td>0.009710</td>\n",
       "      <td>0.075262</td>\n",
       "      <td>0.009710</td>\n",
       "      <td>0.098542</td>\n",
       "      <td>0.721533</td>\n",
       "      <td>0.407818</td>\n",
       "      <td>0.723524</td>\n",
       "      <td>0.702258</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.010500</td>\n",
       "      <td>0.009131</td>\n",
       "      <td>0.072410</td>\n",
       "      <td>0.009131</td>\n",
       "      <td>0.095556</td>\n",
       "      <td>0.733935</td>\n",
       "      <td>0.443153</td>\n",
       "      <td>0.693754</td>\n",
       "      <td>0.707767</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.012400</td>\n",
       "      <td>0.008560</td>\n",
       "      <td>0.072183</td>\n",
       "      <td>0.008560</td>\n",
       "      <td>0.092523</td>\n",
       "      <td>0.744081</td>\n",
       "      <td>0.477949</td>\n",
       "      <td>0.720310</td>\n",
       "      <td>0.709196</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.009100</td>\n",
       "      <td>0.008010</td>\n",
       "      <td>0.071247</td>\n",
       "      <td>0.008010</td>\n",
       "      <td>0.089501</td>\n",
       "      <td>0.741826</td>\n",
       "      <td>0.511493</td>\n",
       "      <td>0.765971</td>\n",
       "      <td>0.720579</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.011300</td>\n",
       "      <td>0.008925</td>\n",
       "      <td>0.072447</td>\n",
       "      <td>0.008925</td>\n",
       "      <td>0.094474</td>\n",
       "      <td>0.745209</td>\n",
       "      <td>0.455693</td>\n",
       "      <td>0.733364</td>\n",
       "      <td>0.717992</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.008700</td>\n",
       "      <td>0.016097</td>\n",
       "      <td>0.106027</td>\n",
       "      <td>0.016097</td>\n",
       "      <td>0.126876</td>\n",
       "      <td>0.508455</td>\n",
       "      <td>0.018313</td>\n",
       "      <td>0.767907</td>\n",
       "      <td>0.721992</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.008200</td>\n",
       "      <td>0.007976</td>\n",
       "      <td>0.069473</td>\n",
       "      <td>0.007976</td>\n",
       "      <td>0.089308</td>\n",
       "      <td>0.756483</td>\n",
       "      <td>0.513592</td>\n",
       "      <td>0.759374</td>\n",
       "      <td>0.724108</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.007200</td>\n",
       "      <td>0.008818</td>\n",
       "      <td>0.073307</td>\n",
       "      <td>0.008818</td>\n",
       "      <td>0.093903</td>\n",
       "      <td>0.740699</td>\n",
       "      <td>0.462256</td>\n",
       "      <td>0.762740</td>\n",
       "      <td>0.728961</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.006600</td>\n",
       "      <td>0.006555</td>\n",
       "      <td>0.062225</td>\n",
       "      <td>0.006555</td>\n",
       "      <td>0.080960</td>\n",
       "      <td>0.795941</td>\n",
       "      <td>0.600278</td>\n",
       "      <td>0.776366</td>\n",
       "      <td>0.722798</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.007400</td>\n",
       "      <td>0.007221</td>\n",
       "      <td>0.064882</td>\n",
       "      <td>0.007221</td>\n",
       "      <td>0.084974</td>\n",
       "      <td>0.786922</td>\n",
       "      <td>0.559662</td>\n",
       "      <td>0.776009</td>\n",
       "      <td>0.731995</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.008400</td>\n",
       "      <td>0.008767</td>\n",
       "      <td>0.075556</td>\n",
       "      <td>0.008767</td>\n",
       "      <td>0.093632</td>\n",
       "      <td>0.720406</td>\n",
       "      <td>0.465360</td>\n",
       "      <td>0.778546</td>\n",
       "      <td>0.731000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.006600</td>\n",
       "      <td>0.007180</td>\n",
       "      <td>0.063156</td>\n",
       "      <td>0.007180</td>\n",
       "      <td>0.084734</td>\n",
       "      <td>0.802706</td>\n",
       "      <td>0.562144</td>\n",
       "      <td>0.759367</td>\n",
       "      <td>0.728134</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.006200</td>\n",
       "      <td>0.008967</td>\n",
       "      <td>0.073374</td>\n",
       "      <td>0.008967</td>\n",
       "      <td>0.094695</td>\n",
       "      <td>0.736189</td>\n",
       "      <td>0.453150</td>\n",
       "      <td>0.777845</td>\n",
       "      <td>0.730942</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.006300</td>\n",
       "      <td>0.007011</td>\n",
       "      <td>0.063248</td>\n",
       "      <td>0.007011</td>\n",
       "      <td>0.083733</td>\n",
       "      <td>0.794814</td>\n",
       "      <td>0.572429</td>\n",
       "      <td>0.774867</td>\n",
       "      <td>0.726700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.006000</td>\n",
       "      <td>0.007731</td>\n",
       "      <td>0.067417</td>\n",
       "      <td>0.007731</td>\n",
       "      <td>0.087928</td>\n",
       "      <td>0.767756</td>\n",
       "      <td>0.528515</td>\n",
       "      <td>0.781062</td>\n",
       "      <td>0.726735</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.006000</td>\n",
       "      <td>0.006507</td>\n",
       "      <td>0.061831</td>\n",
       "      <td>0.006507</td>\n",
       "      <td>0.080665</td>\n",
       "      <td>0.800451</td>\n",
       "      <td>0.603187</td>\n",
       "      <td>0.781134</td>\n",
       "      <td>0.727829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.005800</td>\n",
       "      <td>0.010210</td>\n",
       "      <td>0.082072</td>\n",
       "      <td>0.010210</td>\n",
       "      <td>0.101044</td>\n",
       "      <td>0.678692</td>\n",
       "      <td>0.377356</td>\n",
       "      <td>0.774344</td>\n",
       "      <td>0.717394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.006000</td>\n",
       "      <td>0.006817</td>\n",
       "      <td>0.062813</td>\n",
       "      <td>0.006817</td>\n",
       "      <td>0.082563</td>\n",
       "      <td>0.810598</td>\n",
       "      <td>0.584294</td>\n",
       "      <td>0.772123</td>\n",
       "      <td>0.716365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.005200</td>\n",
       "      <td>0.006539</td>\n",
       "      <td>0.061503</td>\n",
       "      <td>0.006539</td>\n",
       "      <td>0.080862</td>\n",
       "      <td>0.815107</td>\n",
       "      <td>0.601246</td>\n",
       "      <td>0.776184</td>\n",
       "      <td>0.721369</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.005000</td>\n",
       "      <td>0.006825</td>\n",
       "      <td>0.063631</td>\n",
       "      <td>0.006825</td>\n",
       "      <td>0.082613</td>\n",
       "      <td>0.795941</td>\n",
       "      <td>0.583784</td>\n",
       "      <td>0.777521</td>\n",
       "      <td>0.722494</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.004800</td>\n",
       "      <td>0.006545</td>\n",
       "      <td>0.061877</td>\n",
       "      <td>0.006545</td>\n",
       "      <td>0.080901</td>\n",
       "      <td>0.801578</td>\n",
       "      <td>0.600858</td>\n",
       "      <td>0.778631</td>\n",
       "      <td>0.722973</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.004600</td>\n",
       "      <td>0.008484</td>\n",
       "      <td>0.071152</td>\n",
       "      <td>0.008484</td>\n",
       "      <td>0.092106</td>\n",
       "      <td>0.746336</td>\n",
       "      <td>0.482638</td>\n",
       "      <td>0.778283</td>\n",
       "      <td>0.720581</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.004700</td>\n",
       "      <td>0.006436</td>\n",
       "      <td>0.061855</td>\n",
       "      <td>0.006436</td>\n",
       "      <td>0.080223</td>\n",
       "      <td>0.804961</td>\n",
       "      <td>0.607520</td>\n",
       "      <td>0.783088</td>\n",
       "      <td>0.722286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.004300</td>\n",
       "      <td>0.006908</td>\n",
       "      <td>0.064497</td>\n",
       "      <td>0.006908</td>\n",
       "      <td>0.083117</td>\n",
       "      <td>0.784667</td>\n",
       "      <td>0.578698</td>\n",
       "      <td>0.777451</td>\n",
       "      <td>0.718106</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.004100</td>\n",
       "      <td>0.006572</td>\n",
       "      <td>0.062304</td>\n",
       "      <td>0.006572</td>\n",
       "      <td>0.081071</td>\n",
       "      <td>0.806088</td>\n",
       "      <td>0.599185</td>\n",
       "      <td>0.781499</td>\n",
       "      <td>0.717823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.004200</td>\n",
       "      <td>0.006574</td>\n",
       "      <td>0.062289</td>\n",
       "      <td>0.006574</td>\n",
       "      <td>0.081082</td>\n",
       "      <td>0.801578</td>\n",
       "      <td>0.599078</td>\n",
       "      <td>0.781203</td>\n",
       "      <td>0.719918</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.003800</td>\n",
       "      <td>0.007080</td>\n",
       "      <td>0.064447</td>\n",
       "      <td>0.007080</td>\n",
       "      <td>0.084142</td>\n",
       "      <td>0.783540</td>\n",
       "      <td>0.568237</td>\n",
       "      <td>0.778695</td>\n",
       "      <td>0.714007</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.003900</td>\n",
       "      <td>0.006641</td>\n",
       "      <td>0.062338</td>\n",
       "      <td>0.006641</td>\n",
       "      <td>0.081495</td>\n",
       "      <td>0.799324</td>\n",
       "      <td>0.594976</td>\n",
       "      <td>0.775887</td>\n",
       "      <td>0.715136</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.003600</td>\n",
       "      <td>0.006577</td>\n",
       "      <td>0.062137</td>\n",
       "      <td>0.006577</td>\n",
       "      <td>0.081097</td>\n",
       "      <td>0.800451</td>\n",
       "      <td>0.598928</td>\n",
       "      <td>0.776203</td>\n",
       "      <td>0.715612</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.003800</td>\n",
       "      <td>0.006696</td>\n",
       "      <td>0.063143</td>\n",
       "      <td>0.006696</td>\n",
       "      <td>0.081828</td>\n",
       "      <td>0.799324</td>\n",
       "      <td>0.591664</td>\n",
       "      <td>0.775933</td>\n",
       "      <td>0.710297</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.003400</td>\n",
       "      <td>0.007143</td>\n",
       "      <td>0.065113</td>\n",
       "      <td>0.007143</td>\n",
       "      <td>0.084515</td>\n",
       "      <td>0.786922</td>\n",
       "      <td>0.564404</td>\n",
       "      <td>0.775362</td>\n",
       "      <td>0.710160</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.003300</td>\n",
       "      <td>0.006664</td>\n",
       "      <td>0.062877</td>\n",
       "      <td>0.006664</td>\n",
       "      <td>0.081633</td>\n",
       "      <td>0.794814</td>\n",
       "      <td>0.593608</td>\n",
       "      <td>0.776257</td>\n",
       "      <td>0.712062</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.003500</td>\n",
       "      <td>0.006675</td>\n",
       "      <td>0.062867</td>\n",
       "      <td>0.006675</td>\n",
       "      <td>0.081703</td>\n",
       "      <td>0.792559</td>\n",
       "      <td>0.592912</td>\n",
       "      <td>0.773013</td>\n",
       "      <td>0.707193</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.003500</td>\n",
       "      <td>0.006679</td>\n",
       "      <td>0.062748</td>\n",
       "      <td>0.006679</td>\n",
       "      <td>0.081727</td>\n",
       "      <td>0.791432</td>\n",
       "      <td>0.592670</td>\n",
       "      <td>0.774576</td>\n",
       "      <td>0.708864</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.003200</td>\n",
       "      <td>0.006836</td>\n",
       "      <td>0.063458</td>\n",
       "      <td>0.006836</td>\n",
       "      <td>0.082681</td>\n",
       "      <td>0.794814</td>\n",
       "      <td>0.583103</td>\n",
       "      <td>0.772958</td>\n",
       "      <td>0.705839</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.002900</td>\n",
       "      <td>0.006835</td>\n",
       "      <td>0.063421</td>\n",
       "      <td>0.006835</td>\n",
       "      <td>0.082675</td>\n",
       "      <td>0.794814</td>\n",
       "      <td>0.583160</td>\n",
       "      <td>0.772311</td>\n",
       "      <td>0.705416</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.003200</td>\n",
       "      <td>0.006820</td>\n",
       "      <td>0.063374</td>\n",
       "      <td>0.006820</td>\n",
       "      <td>0.082586</td>\n",
       "      <td>0.790304</td>\n",
       "      <td>0.584063</td>\n",
       "      <td>0.772455</td>\n",
       "      <td>0.705394</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.003300</td>\n",
       "      <td>0.006769</td>\n",
       "      <td>0.063148</td>\n",
       "      <td>0.006769</td>\n",
       "      <td>0.082277</td>\n",
       "      <td>0.792559</td>\n",
       "      <td>0.587169</td>\n",
       "      <td>0.772629</td>\n",
       "      <td>0.705708</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.003200</td>\n",
       "      <td>0.006773</td>\n",
       "      <td>0.063161</td>\n",
       "      <td>0.006773</td>\n",
       "      <td>0.082297</td>\n",
       "      <td>0.793687</td>\n",
       "      <td>0.586968</td>\n",
       "      <td>0.772631</td>\n",
       "      <td>0.705743</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=4520, training_loss=0.007707832526184817, metrics={'train_runtime': 3537.8688, 'train_samples_per_second': 20.442, 'train_steps_per_second': 1.278, 'total_flos': 32964070210560.0, 'train_loss': 0.007707832526184817, 'epoch': 10.0})"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
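  {
   "cell_type": "markdown",
   "id": "infr0001",
   "metadata": {},
   "source": [
    "A minimal inference sketch (not part of the original run): score a new sentence-word pair with the fine-tuned regression head. The example pair is made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "infr0002",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical example; any sentence/token pair can be substituted\n",
    "example = {'sentence': 'The physician prescribed an anticoagulant.', 'token': 'anticoagulant'}\n",
    "prompt = generate_prompt(example)\n",
    "inputs = tokenizer(prompt, return_tensors='pt').to(model.device)\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    score = model(**inputs).logits.squeeze().item()\n",
    "print(f\"predicted complexity: {score:.3f}\")"
   ]
  }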
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
