{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90b47c6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Log in with a Hugging Face token read from the environment,\n",
    "# rather than hard-coding the secret in the notebook.\n",
    "login(token=os.environ.get(\"HF_TOKEN\"))\n",
    "\n",
    "# Humicroedit subtask-1: predict the mean funniness grade of an edited headline\n",
    "raw_datasets = load_dataset('SemEvalWorkshop/humicroedit', 'subtask-1')"
   ]
  },
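  {
   "cell_type": "markdown",
   "id": "f0a1b2c3",
   "metadata": {},
   "source": [
    "This notebook fine-tunes `HuggingFaceTB/SmolLM2-135M` as a single-output regression model on Humicroedit subtask-1: given a news headline and a one-word edit, predict the mean funniness grade of the edited headline. Each example provides `original` (the headline with the replaced word marked as `<word/>`), `edit` (the replacement word), `grades` (the individual annotator grades) and `meanGrade` (their mean, used as the regression target). Only a small set of weights is trained, via RoCoFT row updates with rank 3, and quality is tracked with MAE, MSE, RMSE, R2, and Pearson/Spearman correlations."
   ]
  },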
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "de228bb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at HuggingFaceTB/SmolLM2-135M and are newly initialized: ['score.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
    "\n",
    "# Backbone: a small causal LM used with a sequence-classification head.\n",
    "# model_name = \"openai-community/gpt2-medium\"\n",
    "model_name = \"HuggingFaceTB/SmolLM2-135M\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# num_labels=1 makes the head a single-output regression head.\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1).to('cuda')\n",
    "model.config.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "import RoCoFT\n",
    "\n",
    "# RoCoFT parameter-efficient fine-tuning: method='row' with rank=3 restricts\n",
    "# training to a small number of rows per weight matrix.\n",
    "RoCoFT.PEFT(model, method='row', rank=3)"
   ]
  },
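  {
   "cell_type": "markdown",
   "id": "c4d5e6f7",
   "metadata": {},
   "source": [
    "A quick sanity check: count how many parameters RoCoFT leaves trainable versus the total. This sketch relies only on standard PyTorch attributes, so it should hold regardless of how `RoCoFT.PEFT` marks the trainable rows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b9c0d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainable vs. total parameter counts after applying RoCoFT.\n",
    "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "total = sum(p.numel() for p in model.parameters())\n",
    "print(f\"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)\")"
   ]
  },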
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ed721fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
      "    num_rows: 9652\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
      "    num_rows: 2419\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Generates a prompt for scoring the humor of an edited headline.\n",
    "    Args:\n",
    "        data_point (dict): A dictionary containing 'original' and 'edit'.\n",
    "    Returns:\n",
    "        str: The formatted prompt as a string.\n",
    "    \"\"\"\n",
    "    return f\"\"\"# Original Headline: {data_point['original']}. # Edited Headline: {data_point['edit']} # Output: The mean funniness score is\"\"\"  # noqa: E501\n",
    "\n",
    "\n",
    "def add_label_column(example):\n",
    "    \"\"\"Attach the regression target and the formatted prompt to each example.\"\"\"\n",
    "    example['labels'] = float(example['meanGrade'])\n",
    "    example['input'] = generate_prompt(example)\n",
    "    return example\n",
    "\n",
    "\n",
    "# Map the function over the train and validation splits\n",
    "train_data = raw_datasets['train'].map(add_label_column)\n",
    "val_data = raw_datasets['validation'].map(add_label_column)\n",
    "\n",
    "# Inspect the updated datasets\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9e33204c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
       "    num_rows: 9652\n",
       "})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a9fde6d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 9652/9652 [00:02<00:00, 3464.23 examples/s]\n",
      "Map: 100%|██████████| 2419/2419 [00:00<00:00, 3231.61 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "\n",
    "# Pad on the left so the end of each prompt sits at the last position of the batch.\n",
    "tokenizer.padding_side = 'left'\n",
    "\n",
    "# Raw columns that are no longer needed once the text is tokenized\n",
    "col_to_delete = ['id', 'original', 'edit', 'grades', 'meanGrade', 'input']\n",
    "\n",
    "\n",
    "def preprocessing_function(examples):\n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Data collator that pads each batch to the longest sequence in that batch\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1931ed6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# Original Headline: Dozens dead in possible gas <attack/> in Syria ; regime denies allegation. # Edited Headline: bloating # Output: The mean funniness score is'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(tokenized_train_data['input_ids'][10])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd6b985",
   "metadata": {},
   "source": [
    "Sanity checks on the tokenized data: inspect the validation split, the longest tokenized training example, and how many examples exceed the 512-token cap."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "25900f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'original', 'edit', 'grades', 'meanGrade', 'labels', 'input'],\n",
       "    num_rows: 2419\n",
       "})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1fdaa612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "81"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_lengths = [len(ids) for ids in tokenized_train_data['input_ids']]\n",
    "mx = max(all_lengths)\n",
    "mx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d6618d0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "count = sum(len(ids) > 512 for ids in tokenized_train_data['input_ids'])\n",
    "print(count)\n"
   ]
  },
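  {
   "cell_type": "markdown",
   "id": "e2f3a4b5",
   "metadata": {},
   "source": [
    "Regression metrics for the Trainer: MAE, MSE, RMSE, a tolerance-based \"accuracy\" (the fraction of predictions within 0.1 of the gold score), R2, and Pearson/Spearman correlations between predictions and labels."
   ]
  },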
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "bef34afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # If predictions are logits or have extra dimensions, squeeze\n",
    "    if predictions.ndim > 1:\n",
    "        predictions = predictions.squeeze()\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "    \n",
    "    # Define an \"accuracy\" for regression:\n",
    "    # Example: within some threshold tolerance\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }\n"
   ]
  },
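  {
   "cell_type": "markdown",
   "id": "b6c7d8e9",
   "metadata": {},
   "source": [
    "A throwaway check of `compute_metrics` on made-up values (the numbers below are arbitrary), mainly confirming that `(n, 1)`-shaped predictions are squeezed before the metrics are computed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0e1f2a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Arbitrary toy values: predictions shaped (n, 1), like the regression head's logits.\n",
    "toy_predictions = np.array([[1.0], [2.0], [0.5]])\n",
    "toy_labels = np.array([1.1, 1.8, 0.7])\n",
    "compute_metrics((toy_predictions, toy_labels))"
   ]
  },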
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7dbcf96a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=6e-4,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=10,\n",
    "    weight_decay=0.20,\n",
    "    eval_strategy=\"steps\",\n",
    "    save_strategy=\"steps\",\n",
    "    save_total_limit=2,\n",
    "    save_steps=10000000,  # effectively disables intermediate checkpoints\n",
    "    logging_steps=100,\n",
    "    load_best_model_at_end=True,\n",
    "    lr_scheduler_type=\"cosine\",  # other options: 'linear', 'cosine_with_restarts', 'polynomial', ...\n",
    "    warmup_steps=100,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
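  {
   "cell_type": "markdown",
   "id": "f4a5b6c7",
   "metadata": {},
   "source": [
    "Run training: 10 epochs with a cosine schedule, evaluating and logging the metrics above every 100 steps; the table below is the Trainer's live progress log."
   ]
  },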
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "557cdbf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6040' max='6040' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6040/6040 20:41, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.444800</td>\n",
       "      <td>0.339354</td>\n",
       "      <td>0.462291</td>\n",
       "      <td>0.339354</td>\n",
       "      <td>0.582541</td>\n",
       "      <td>0.128152</td>\n",
       "      <td>-0.014373</td>\n",
       "      <td>0.188231</td>\n",
       "      <td>0.184836</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.372400</td>\n",
       "      <td>0.320095</td>\n",
       "      <td>0.463861</td>\n",
       "      <td>0.320095</td>\n",
       "      <td>0.565770</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>0.043195</td>\n",
       "      <td>0.220797</td>\n",
       "      <td>0.208801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.386600</td>\n",
       "      <td>0.312726</td>\n",
       "      <td>0.458634</td>\n",
       "      <td>0.312726</td>\n",
       "      <td>0.559219</td>\n",
       "      <td>0.122778</td>\n",
       "      <td>0.065223</td>\n",
       "      <td>0.260364</td>\n",
       "      <td>0.251518</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.348200</td>\n",
       "      <td>0.319010</td>\n",
       "      <td>0.446098</td>\n",
       "      <td>0.319010</td>\n",
       "      <td>0.564810</td>\n",
       "      <td>0.135180</td>\n",
       "      <td>0.046437</td>\n",
       "      <td>0.348196</td>\n",
       "      <td>0.346679</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.351600</td>\n",
       "      <td>0.457294</td>\n",
       "      <td>0.572004</td>\n",
       "      <td>0.457294</td>\n",
       "      <td>0.676236</td>\n",
       "      <td>0.087640</td>\n",
       "      <td>-0.366912</td>\n",
       "      <td>0.368479</td>\n",
       "      <td>0.368128</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.302900</td>\n",
       "      <td>0.304419</td>\n",
       "      <td>0.441398</td>\n",
       "      <td>0.304419</td>\n",
       "      <td>0.551741</td>\n",
       "      <td>0.132286</td>\n",
       "      <td>0.090054</td>\n",
       "      <td>0.338718</td>\n",
       "      <td>0.328955</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.333800</td>\n",
       "      <td>0.309634</td>\n",
       "      <td>0.460365</td>\n",
       "      <td>0.309634</td>\n",
       "      <td>0.556447</td>\n",
       "      <td>0.123191</td>\n",
       "      <td>0.074465</td>\n",
       "      <td>0.339494</td>\n",
       "      <td>0.329715</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.337200</td>\n",
       "      <td>0.289148</td>\n",
       "      <td>0.433104</td>\n",
       "      <td>0.289148</td>\n",
       "      <td>0.537725</td>\n",
       "      <td>0.136833</td>\n",
       "      <td>0.135701</td>\n",
       "      <td>0.371155</td>\n",
       "      <td>0.360933</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.303300</td>\n",
       "      <td>0.304589</td>\n",
       "      <td>0.452128</td>\n",
       "      <td>0.304589</td>\n",
       "      <td>0.551896</td>\n",
       "      <td>0.138074</td>\n",
       "      <td>0.089544</td>\n",
       "      <td>0.372695</td>\n",
       "      <td>0.360916</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.306800</td>\n",
       "      <td>0.303788</td>\n",
       "      <td>0.452031</td>\n",
       "      <td>0.303788</td>\n",
       "      <td>0.551170</td>\n",
       "      <td>0.136833</td>\n",
       "      <td>0.091939</td>\n",
       "      <td>0.363702</td>\n",
       "      <td>0.342302</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.316600</td>\n",
       "      <td>0.285615</td>\n",
       "      <td>0.425702</td>\n",
       "      <td>0.285615</td>\n",
       "      <td>0.534429</td>\n",
       "      <td>0.140554</td>\n",
       "      <td>0.146261</td>\n",
       "      <td>0.402460</td>\n",
       "      <td>0.393607</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.287000</td>\n",
       "      <td>0.313452</td>\n",
       "      <td>0.437687</td>\n",
       "      <td>0.313452</td>\n",
       "      <td>0.559868</td>\n",
       "      <td>0.140141</td>\n",
       "      <td>0.063053</td>\n",
       "      <td>0.393461</td>\n",
       "      <td>0.389200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.287700</td>\n",
       "      <td>0.290476</td>\n",
       "      <td>0.440315</td>\n",
       "      <td>0.290476</td>\n",
       "      <td>0.538958</td>\n",
       "      <td>0.139314</td>\n",
       "      <td>0.131730</td>\n",
       "      <td>0.395921</td>\n",
       "      <td>0.385334</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.262200</td>\n",
       "      <td>0.288095</td>\n",
       "      <td>0.440190</td>\n",
       "      <td>0.288095</td>\n",
       "      <td>0.536745</td>\n",
       "      <td>0.140141</td>\n",
       "      <td>0.138847</td>\n",
       "      <td>0.402354</td>\n",
       "      <td>0.393917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.296000</td>\n",
       "      <td>0.294490</td>\n",
       "      <td>0.431779</td>\n",
       "      <td>0.294490</td>\n",
       "      <td>0.542669</td>\n",
       "      <td>0.131459</td>\n",
       "      <td>0.119732</td>\n",
       "      <td>0.383588</td>\n",
       "      <td>0.376863</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.300100</td>\n",
       "      <td>0.281996</td>\n",
       "      <td>0.427070</td>\n",
       "      <td>0.281996</td>\n",
       "      <td>0.531033</td>\n",
       "      <td>0.140967</td>\n",
       "      <td>0.157077</td>\n",
       "      <td>0.401247</td>\n",
       "      <td>0.389001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.271000</td>\n",
       "      <td>0.304987</td>\n",
       "      <td>0.455943</td>\n",
       "      <td>0.304987</td>\n",
       "      <td>0.552256</td>\n",
       "      <td>0.132699</td>\n",
       "      <td>0.088355</td>\n",
       "      <td>0.382852</td>\n",
       "      <td>0.373791</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.275800</td>\n",
       "      <td>0.294149</td>\n",
       "      <td>0.432392</td>\n",
       "      <td>0.294149</td>\n",
       "      <td>0.542355</td>\n",
       "      <td>0.142621</td>\n",
       "      <td>0.120752</td>\n",
       "      <td>0.394054</td>\n",
       "      <td>0.390897</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.232100</td>\n",
       "      <td>0.299179</td>\n",
       "      <td>0.447951</td>\n",
       "      <td>0.299179</td>\n",
       "      <td>0.546973</td>\n",
       "      <td>0.127739</td>\n",
       "      <td>0.105715</td>\n",
       "      <td>0.395133</td>\n",
       "      <td>0.376449</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.273200</td>\n",
       "      <td>0.294637</td>\n",
       "      <td>0.441304</td>\n",
       "      <td>0.294637</td>\n",
       "      <td>0.542804</td>\n",
       "      <td>0.138074</td>\n",
       "      <td>0.119293</td>\n",
       "      <td>0.401291</td>\n",
       "      <td>0.395313</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.248500</td>\n",
       "      <td>0.297095</td>\n",
       "      <td>0.430347</td>\n",
       "      <td>0.297095</td>\n",
       "      <td>0.545064</td>\n",
       "      <td>0.141381</td>\n",
       "      <td>0.111946</td>\n",
       "      <td>0.403672</td>\n",
       "      <td>0.392106</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.252500</td>\n",
       "      <td>0.334727</td>\n",
       "      <td>0.477533</td>\n",
       "      <td>0.334727</td>\n",
       "      <td>0.578556</td>\n",
       "      <td>0.126085</td>\n",
       "      <td>-0.000541</td>\n",
       "      <td>0.412147</td>\n",
       "      <td>0.403712</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.243200</td>\n",
       "      <td>0.308343</td>\n",
       "      <td>0.457461</td>\n",
       "      <td>0.308343</td>\n",
       "      <td>0.555287</td>\n",
       "      <td>0.123191</td>\n",
       "      <td>0.078322</td>\n",
       "      <td>0.403594</td>\n",
       "      <td>0.389063</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.251800</td>\n",
       "      <td>0.295586</td>\n",
       "      <td>0.446400</td>\n",
       "      <td>0.295586</td>\n",
       "      <td>0.543679</td>\n",
       "      <td>0.130632</td>\n",
       "      <td>0.116455</td>\n",
       "      <td>0.394792</td>\n",
       "      <td>0.378842</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.212400</td>\n",
       "      <td>0.303924</td>\n",
       "      <td>0.449335</td>\n",
       "      <td>0.303924</td>\n",
       "      <td>0.551293</td>\n",
       "      <td>0.136420</td>\n",
       "      <td>0.091534</td>\n",
       "      <td>0.392291</td>\n",
       "      <td>0.381142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.213400</td>\n",
       "      <td>0.319056</td>\n",
       "      <td>0.462101</td>\n",
       "      <td>0.319056</td>\n",
       "      <td>0.564851</td>\n",
       "      <td>0.115337</td>\n",
       "      <td>0.046300</td>\n",
       "      <td>0.394497</td>\n",
       "      <td>0.380666</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.212600</td>\n",
       "      <td>0.303641</td>\n",
       "      <td>0.434294</td>\n",
       "      <td>0.303641</td>\n",
       "      <td>0.551036</td>\n",
       "      <td>0.145101</td>\n",
       "      <td>0.092379</td>\n",
       "      <td>0.397273</td>\n",
       "      <td>0.382735</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.211100</td>\n",
       "      <td>0.295089</td>\n",
       "      <td>0.438539</td>\n",
       "      <td>0.295089</td>\n",
       "      <td>0.543221</td>\n",
       "      <td>0.129806</td>\n",
       "      <td>0.117941</td>\n",
       "      <td>0.396065</td>\n",
       "      <td>0.379336</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.206800</td>\n",
       "      <td>0.292745</td>\n",
       "      <td>0.437889</td>\n",
       "      <td>0.292745</td>\n",
       "      <td>0.541059</td>\n",
       "      <td>0.129806</td>\n",
       "      <td>0.124949</td>\n",
       "      <td>0.382155</td>\n",
       "      <td>0.367464</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.231200</td>\n",
       "      <td>0.296092</td>\n",
       "      <td>0.434777</td>\n",
       "      <td>0.296093</td>\n",
       "      <td>0.544144</td>\n",
       "      <td>0.146755</td>\n",
       "      <td>0.114942</td>\n",
       "      <td>0.383096</td>\n",
       "      <td>0.372021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.183600</td>\n",
       "      <td>0.326484</td>\n",
       "      <td>0.455902</td>\n",
       "      <td>0.326484</td>\n",
       "      <td>0.571388</td>\n",
       "      <td>0.137247</td>\n",
       "      <td>0.024097</td>\n",
       "      <td>0.376288</td>\n",
       "      <td>0.355832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.176100</td>\n",
       "      <td>0.315921</td>\n",
       "      <td>0.450060</td>\n",
       "      <td>0.315921</td>\n",
       "      <td>0.562068</td>\n",
       "      <td>0.133940</td>\n",
       "      <td>0.055673</td>\n",
       "      <td>0.385505</td>\n",
       "      <td>0.372004</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.188500</td>\n",
       "      <td>0.317862</td>\n",
       "      <td>0.460200</td>\n",
       "      <td>0.317862</td>\n",
       "      <td>0.563793</td>\n",
       "      <td>0.128566</td>\n",
       "      <td>0.049868</td>\n",
       "      <td>0.380915</td>\n",
       "      <td>0.366485</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.173000</td>\n",
       "      <td>0.308996</td>\n",
       "      <td>0.449787</td>\n",
       "      <td>0.308996</td>\n",
       "      <td>0.555874</td>\n",
       "      <td>0.129806</td>\n",
       "      <td>0.076372</td>\n",
       "      <td>0.384816</td>\n",
       "      <td>0.369412</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.171200</td>\n",
       "      <td>0.305967</td>\n",
       "      <td>0.452118</td>\n",
       "      <td>0.305967</td>\n",
       "      <td>0.553142</td>\n",
       "      <td>0.123605</td>\n",
       "      <td>0.085427</td>\n",
       "      <td>0.378250</td>\n",
       "      <td>0.361823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.173900</td>\n",
       "      <td>0.317106</td>\n",
       "      <td>0.458159</td>\n",
       "      <td>0.317106</td>\n",
       "      <td>0.563122</td>\n",
       "      <td>0.118644</td>\n",
       "      <td>0.052129</td>\n",
       "      <td>0.379650</td>\n",
       "      <td>0.366832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.150000</td>\n",
       "      <td>0.329562</td>\n",
       "      <td>0.463532</td>\n",
       "      <td>0.329562</td>\n",
       "      <td>0.574075</td>\n",
       "      <td>0.126499</td>\n",
       "      <td>0.014898</td>\n",
       "      <td>0.373459</td>\n",
       "      <td>0.362924</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.140900</td>\n",
       "      <td>0.322221</td>\n",
       "      <td>0.455470</td>\n",
       "      <td>0.322221</td>\n",
       "      <td>0.567645</td>\n",
       "      <td>0.132699</td>\n",
       "      <td>0.036841</td>\n",
       "      <td>0.372227</td>\n",
       "      <td>0.364012</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.139300</td>\n",
       "      <td>0.339714</td>\n",
       "      <td>0.462480</td>\n",
       "      <td>0.339714</td>\n",
       "      <td>0.582850</td>\n",
       "      <td>0.140554</td>\n",
       "      <td>-0.015449</td>\n",
       "      <td>0.370336</td>\n",
       "      <td>0.358806</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.145000</td>\n",
       "      <td>0.321901</td>\n",
       "      <td>0.464709</td>\n",
       "      <td>0.321901</td>\n",
       "      <td>0.567363</td>\n",
       "      <td>0.111203</td>\n",
       "      <td>0.037798</td>\n",
       "      <td>0.370936</td>\n",
       "      <td>0.354793</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.140300</td>\n",
       "      <td>0.324102</td>\n",
       "      <td>0.456884</td>\n",
       "      <td>0.324102</td>\n",
       "      <td>0.569299</td>\n",
       "      <td>0.129806</td>\n",
       "      <td>0.031219</td>\n",
       "      <td>0.377012</td>\n",
       "      <td>0.362933</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.149600</td>\n",
       "      <td>0.319583</td>\n",
       "      <td>0.456179</td>\n",
       "      <td>0.319583</td>\n",
       "      <td>0.565317</td>\n",
       "      <td>0.125258</td>\n",
       "      <td>0.044724</td>\n",
       "      <td>0.373741</td>\n",
       "      <td>0.361737</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.125400</td>\n",
       "      <td>0.321343</td>\n",
       "      <td>0.458555</td>\n",
       "      <td>0.321343</td>\n",
       "      <td>0.566871</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>0.039465</td>\n",
       "      <td>0.372474</td>\n",
       "      <td>0.360563</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.103900</td>\n",
       "      <td>0.331612</td>\n",
       "      <td>0.461553</td>\n",
       "      <td>0.331612</td>\n",
       "      <td>0.575857</td>\n",
       "      <td>0.126912</td>\n",
       "      <td>0.008770</td>\n",
       "      <td>0.364913</td>\n",
       "      <td>0.352698</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.115700</td>\n",
       "      <td>0.339531</td>\n",
       "      <td>0.466737</td>\n",
       "      <td>0.339531</td>\n",
       "      <td>0.582693</td>\n",
       "      <td>0.127325</td>\n",
       "      <td>-0.014903</td>\n",
       "      <td>0.360697</td>\n",
       "      <td>0.352046</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.116000</td>\n",
       "      <td>0.347691</td>\n",
       "      <td>0.469449</td>\n",
       "      <td>0.347691</td>\n",
       "      <td>0.589653</td>\n",
       "      <td>0.131459</td>\n",
       "      <td>-0.039292</td>\n",
       "      <td>0.359679</td>\n",
       "      <td>0.348906</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.117000</td>\n",
       "      <td>0.334345</td>\n",
       "      <td>0.465822</td>\n",
       "      <td>0.334345</td>\n",
       "      <td>0.578226</td>\n",
       "      <td>0.124018</td>\n",
       "      <td>0.000601</td>\n",
       "      <td>0.363997</td>\n",
       "      <td>0.353892</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.121700</td>\n",
       "      <td>0.333559</td>\n",
       "      <td>0.464895</td>\n",
       "      <td>0.333559</td>\n",
       "      <td>0.577546</td>\n",
       "      <td>0.117404</td>\n",
       "      <td>0.002950</td>\n",
       "      <td>0.358653</td>\n",
       "      <td>0.347693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.102900</td>\n",
       "      <td>0.342909</td>\n",
       "      <td>0.471394</td>\n",
       "      <td>0.342909</td>\n",
       "      <td>0.585585</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>-0.024999</td>\n",
       "      <td>0.361607</td>\n",
       "      <td>0.350123</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.092200</td>\n",
       "      <td>0.342634</td>\n",
       "      <td>0.471050</td>\n",
       "      <td>0.342634</td>\n",
       "      <td>0.585350</td>\n",
       "      <td>0.117404</td>\n",
       "      <td>-0.024177</td>\n",
       "      <td>0.357441</td>\n",
       "      <td>0.344836</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.099900</td>\n",
       "      <td>0.344623</td>\n",
       "      <td>0.472897</td>\n",
       "      <td>0.344623</td>\n",
       "      <td>0.587046</td>\n",
       "      <td>0.121538</td>\n",
       "      <td>-0.030123</td>\n",
       "      <td>0.356781</td>\n",
       "      <td>0.344968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.093200</td>\n",
       "      <td>0.341451</td>\n",
       "      <td>0.470539</td>\n",
       "      <td>0.341451</td>\n",
       "      <td>0.584338</td>\n",
       "      <td>0.118231</td>\n",
       "      <td>-0.020640</td>\n",
       "      <td>0.355983</td>\n",
       "      <td>0.344695</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.099000</td>\n",
       "      <td>0.342575</td>\n",
       "      <td>0.471342</td>\n",
       "      <td>0.342575</td>\n",
       "      <td>0.585299</td>\n",
       "      <td>0.117817</td>\n",
       "      <td>-0.024001</td>\n",
       "      <td>0.357360</td>\n",
       "      <td>0.347085</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.101100</td>\n",
       "      <td>0.342264</td>\n",
       "      <td>0.472514</td>\n",
       "      <td>0.342264</td>\n",
       "      <td>0.585033</td>\n",
       "      <td>0.125258</td>\n",
       "      <td>-0.023071</td>\n",
       "      <td>0.358495</td>\n",
       "      <td>0.348520</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.088000</td>\n",
       "      <td>0.343715</td>\n",
       "      <td>0.472221</td>\n",
       "      <td>0.343715</td>\n",
       "      <td>0.586272</td>\n",
       "      <td>0.119471</td>\n",
       "      <td>-0.027409</td>\n",
       "      <td>0.357074</td>\n",
       "      <td>0.346859</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.099000</td>\n",
       "      <td>0.345388</td>\n",
       "      <td>0.472851</td>\n",
       "      <td>0.345388</td>\n",
       "      <td>0.587698</td>\n",
       "      <td>0.116577</td>\n",
       "      <td>-0.032410</td>\n",
       "      <td>0.355948</td>\n",
       "      <td>0.345646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.087100</td>\n",
       "      <td>0.345114</td>\n",
       "      <td>0.473582</td>\n",
       "      <td>0.345114</td>\n",
       "      <td>0.587464</td>\n",
       "      <td>0.116577</td>\n",
       "      <td>-0.031591</td>\n",
       "      <td>0.356062</td>\n",
       "      <td>0.345926</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.086000</td>\n",
       "      <td>0.345319</td>\n",
       "      <td>0.473373</td>\n",
       "      <td>0.345319</td>\n",
       "      <td>0.587639</td>\n",
       "      <td>0.118644</td>\n",
       "      <td>-0.032204</td>\n",
       "      <td>0.355999</td>\n",
       "      <td>0.345866</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.087900</td>\n",
       "      <td>0.345500</td>\n",
       "      <td>0.473543</td>\n",
       "      <td>0.345500</td>\n",
       "      <td>0.587792</td>\n",
       "      <td>0.118644</td>\n",
       "      <td>-0.032744</td>\n",
       "      <td>0.355915</td>\n",
       "      <td>0.345780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.082300</td>\n",
       "      <td>0.345540</td>\n",
       "      <td>0.473702</td>\n",
       "      <td>0.345540</td>\n",
       "      <td>0.587827</td>\n",
       "      <td>0.118644</td>\n",
       "      <td>-0.032864</td>\n",
       "      <td>0.355906</td>\n",
       "      <td>0.345802</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=6040, training_loss=0.20465247362654732, metrics={'train_runtime': 1242.5096, 'train_samples_per_second': 77.681, 'train_steps_per_second': 4.861, 'total_flos': 13980608308224.0, 'train_loss': 0.20465247362654732, 'epoch': 10.0})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  },
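  {
   "cell_type": "markdown",
   "id": "c8d9e0f1",
   "metadata": {},
   "source": [
    "A minimal inference sketch, assuming the cells above have been run: score one validation headline with the fine-tuned regression head and compare it with the gold `meanGrade`. The example index is arbitrary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2b3c4d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score a single validation example with the fine-tuned model.\n",
    "example = val_data[0]  # arbitrary choice\n",
    "inputs = tokenizer(example['input'], return_tensors='pt').to(model.device)\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    predicted_score = model(**inputs).logits.squeeze().item()\n",
    "\n",
    "print(example['input'])\n",
    "print(f\"predicted: {predicted_score:.3f}   gold meanGrade: {example['labels']:.3f}\")"
   ]
  },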
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc4e83df",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
