{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "90b47c6c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"  # uncomment (before torch initialises CUDA) to pin one GPU\n",
    "\n",
    "import logging\n",
    "from io import StringIO\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import requests\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from torch.utils.data import Dataset\n",
    "from transformers import AutoModelForQuestionAnswering, AutoTokenizer, Trainer, TrainingArguments\n",
    "\n",
    "# Load the SICK sentence-pair dataset from the Hugging Face Hub (train / validation / test splits).\n",
    "raw_datasets = load_dataset('RobZamp/sick')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "092be135",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset'],\n",
       "        num_rows: 4439\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset'],\n",
       "        num_rows: 495\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset'],\n",
       "        num_rows: 4906\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the split sizes and the columns available for building prompts and targets\n",
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "de228bb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/miniconda3/envs/emnlp_2/lib/python3.11/site-packages/transformers/convert_slow_tokenizer.py:559: UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoConfig\n",
    "\n",
    "model_name = \"microsoft/deberta-v3-large\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "# DebertaV2ForSequenceClassification pools the hidden state at position 0 (the [CLS] token).\n",
    "# Right padding keeps [CLS] at position 0 for every example in a batch; left padding would\n",
    "# put a [PAD] token there for every sequence shorter than the longest one in the batch.\n",
    "tokenizer.padding_side = 'right'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ed721fb1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Dataset: Dataset({\n",
      "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
      "    num_rows: 4439\n",
      "})\n",
      "Validation Dataset: Dataset({\n",
      "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
      "    num_rows: 495\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "\n",
    "def generate_prompt(data_point):\n",
    "    \"\"\"\n",
    "    Render one SICK sentence pair as the text fed to the regression model.\n",
    "\n",
    "    Args:\n",
    "        data_point (dict): A row with at least 'sentence_A' and 'sentence_B'.\n",
    "    Returns:\n",
    "        str: '# Sentence-1:: <A>. # Sentence-2: <B> # Output: The similarity is'\n",
    "    \"\"\"\n",
    "    # NOTE(review): 'Sentence-1::' has a double colon while 'Sentence-2:' has one.\n",
    "    # Kept verbatim so the tokenised inputs (and the reported results) are unchanged.\n",
    "    first, second = data_point['sentence_A'], data_point['sentence_B']\n",
    "    return f\"# Sentence-1:: {first}. # Sentence-2: {second} # Output: The similarity is\"\n",
    "\n",
    "\n",
    "def add_label_column(example):\n",
    "    \"\"\"Attach the float regression target ('labels') and the prompt text ('input') to a row.\n",
    "\n",
    "    The row dict is modified in place and returned, as `Dataset.map` expects.\n",
    "    \"\"\"\n",
    "    example.update(\n",
    "        labels=float(example['relatedness_score']),\n",
    "        input=generate_prompt(example),\n",
    "    )\n",
    "    return example\n",
    "\n",
    "\n",
    "# Build 'labels' and 'input' for the two splits used during training / model selection.\n",
    "train_data, val_data = (\n",
    "    raw_datasets[split].map(add_label_column) for split in ('train', 'validation')\n",
    ")\n",
    "\n",
    "print(\"Train Dataset:\", train_data)\n",
    "print(\"Validation Dataset:\", val_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9e33204c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at row 10's original 'label' column (the regression target built above is 'labels').\n",
    "train_data[10]['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a9fde6d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Map: 100%|██████████| 4439/4439 [00:00<00:00, 19704.84 examples/s]\n",
      "Map: 100%|██████████| 495/495 [00:00<00:00, 15073.08 examples/s]\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
    "\n",
    "# DebertaV2ForSequenceClassification pools position 0 ([CLS]). Pad on the right so [CLS]\n",
    "# stays at position 0 in every padded batch; left padding would put [PAD] there instead.\n",
    "tokenizer.padding_side = 'right'\n",
    "\n",
    "# Drop every column except the regression target 'labels'; the tokenizer supplies the\n",
    "# model inputs (input_ids, attention_mask, ...).\n",
    "col_to_delete = [column for column in train_data.column_names if column != 'labels']\n",
    "\n",
    "mask_token = tokenizer.mask_token\n",
    "\n",
    "\n",
    "def preprocessing_function(examples):\n",
    "    \"\"\"Tokenise a batch of prompts, truncating to at most 512 tokens (no padding here).\"\"\"\n",
    "    return tokenizer(examples['input'], truncation=True, max_length=512)\n",
    "\n",
    "\n",
    "tokenized_train_data = train_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_val_data = val_data.map(preprocessing_function, batched=True, remove_columns=col_to_delete)\n",
    "tokenized_train_data.set_format(\"torch\")\n",
    "tokenized_val_data.set_format(\"torch\")\n",
    "\n",
    "# Pads each batch dynamically to its longest sequence, on the tokenizer's padding_side.\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1931ed6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] # Sentence-1:: A person on a black motorbike is doing tricks with a jacket. # Sentence-2: A person is riding the bicycle on one wheel # Output: The similarity is[SEP]'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode one tokenised training example to check the prompt survives tokenisation intact.\n",
    "tokenizer.decode(tokenized_train_data[10]['input_ids'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abd6b985",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "25900f05",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'sentence_A', 'sentence_B', 'label', 'relatedness_score', 'entailment_AB', 'entailment_BA', 'sentence_A_original', 'sentence_B_original', 'sentence_A_dataset', 'sentence_B_dataset', 'labels', 'input'],\n",
       "    num_rows: 495\n",
       "})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Validation split after add_label_column: original columns plus 'labels' and 'input'.\n",
    "val_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1fdaa612",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "70"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Longest tokenised training prompt, measured in tokens.\n",
    "all_lengths = list(map(len, tokenized_train_data['input_ids']))\n",
    "mx = max(all_lengths)\n",
    "mx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d6618d0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: number of training prompts longer than 512 tokens. Tokenisation above\n",
    "# truncates at max_length=512, so this is expected to print 0.\n",
    "count = len([ids for ids in tokenized_train_data['input_ids'] if len(ids) > 512])\n",
    "print(count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7a46cd19",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers import RobertaForSequenceClassification, DebertaV2ForMaskedLM\n",
    "from transformers.activations import ACT2FN\n",
    "import random\n",
    "\n",
    "from transformers import AutoModelForSequenceClassification, set_seed\n",
    "\n",
    "SEED = 42\n",
    "# Seed before loading: the classifier and pooler weights are freshly initialised (see the\n",
    "# warning in this cell's output), so without a seed every run starts from a different head.\n",
    "set_seed(SEED)\n",
    "\n",
    "config = AutoConfig.from_pretrained(model_name)\n",
    "# NOTE(review): mask_token_id looks like a leftover from a masked-LM variant of this\n",
    "# experiment and appears unused by sequence classification — confirm before removing.\n",
    "config.mask_token_id = tokenizer.mask_token_id\n",
    "# A single output unit makes the classification head a regressor on relatedness_score.\n",
    "config.num_labels = 1\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "864ccb2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import RoCoFT\n",
    "\n",
    "# Prepare `model` for RoCoFT parameter-efficient fine-tuning. The return value is discarded,\n",
    "# so PEFT presumably modifies the model in place — confirm against the RoCoFT docs.\n",
    "# Optional argument (presumably restricts which sub-modules are adapted):\n",
    "#     targets=['key', 'value', 'dense', 'query']\n",
    "RoCoFT.PEFT(model, rank=3, method='column')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bef34afd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n",
    "from scipy.stats import pearsonr, spearmanr\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    # If predictions are logits or have extra dimensions, squeeze\n",
    "    if predictions.ndim > 1:\n",
    "        predictions = predictions.squeeze()\n",
    "\n",
    "    mae = mean_absolute_error(labels, predictions)\n",
    "    mse = mean_squared_error(labels, predictions)\n",
    "    rmse = np.sqrt(mse)\n",
    "    r2 = r2_score(labels, predictions)\n",
    "    \n",
    "    # Define an \"accuracy\" for regression:\n",
    "    # Example: within some threshold tolerance\n",
    "    tolerance = 0.1  # you can change this\n",
    "    acc = np.mean(np.abs(predictions - labels) < tolerance)\n",
    "\n",
    "    pearson_corr, _ = pearsonr(predictions, labels)\n",
    "    spearman_corr, _ = spearmanr(predictions, labels)\n",
    "\n",
    "    return {\n",
    "        \"MAE\": mae,\n",
    "        \"MSE\": mse,\n",
    "        \"RMSE\": rmse,\n",
    "        \"Accuracy\": acc,\n",
    "        \"R2\": r2,\n",
    "        \"Pearson\": pearson_corr,\n",
    "        \"Spearman's Rank\": spearman_corr\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7dbcf96a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='dir',\n",
    "    learning_rate=1e-3,\n",
    "    per_device_train_batch_size=14,\n",
    "    per_device_eval_batch_size=64,\n",
    "    gradient_accumulation_steps=1,\n",
    "    num_train_epochs=20,\n",
    "    weight_decay=0.20,\n",
    "    lr_scheduler_type=\"cosine\",  # alternatives: 'linear', 'cosine_with_restarts', 'polynomial', ...\n",
    "    warmup_steps=100,\n",
    "    seed=42,\n",
    "    logging_steps=100,\n",
    "    # Evaluate and checkpoint on the same 100-step grid. load_best_model_at_end can only\n",
    "    # reload a checkpoint that was actually written (and requires save_steps to be a\n",
    "    # multiple of eval_steps). With a save interval longer than the whole run (6360 steps\n",
    "    # here), nothing is saved and the final model is simply the last step's weights.\n",
    "    # Each save writes a full checkpoint; raise eval_steps and save_steps together to save less often.\n",
    "    eval_strategy=\"steps\",\n",
    "    eval_steps=100,\n",
    "    save_strategy=\"steps\",\n",
    "    save_steps=100,\n",
    "    # With load_best_model_at_end, the best checkpoint is retained in addition to the newest.\n",
    "    save_total_limit=2,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"eval_loss\",\n",
    "    greater_is_better=False,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_train_data,\n",
    "    eval_dataset=tokenized_val_data,\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "557cdbf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6360' max='6360' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6360/6360 30:16, Epoch 20/20]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Mae</th>\n",
       "      <th>Mse</th>\n",
       "      <th>Rmse</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>R2</th>\n",
       "      <th>Pearson</th>\n",
       "      <th>Spearman's rank</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>4.917400</td>\n",
       "      <td>0.635824</td>\n",
       "      <td>0.645116</td>\n",
       "      <td>0.635824</td>\n",
       "      <td>0.797386</td>\n",
       "      <td>0.098990</td>\n",
       "      <td>0.372697</td>\n",
       "      <td>0.819987</td>\n",
       "      <td>0.815980</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.436900</td>\n",
       "      <td>0.298303</td>\n",
       "      <td>0.419064</td>\n",
       "      <td>0.298303</td>\n",
       "      <td>0.546171</td>\n",
       "      <td>0.179798</td>\n",
       "      <td>0.705695</td>\n",
       "      <td>0.901795</td>\n",
       "      <td>0.869153</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.260100</td>\n",
       "      <td>0.224157</td>\n",
       "      <td>0.361352</td>\n",
       "      <td>0.224157</td>\n",
       "      <td>0.473453</td>\n",
       "      <td>0.214141</td>\n",
       "      <td>0.778847</td>\n",
       "      <td>0.913138</td>\n",
       "      <td>0.889451</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.191600</td>\n",
       "      <td>0.192075</td>\n",
       "      <td>0.342934</td>\n",
       "      <td>0.192075</td>\n",
       "      <td>0.438264</td>\n",
       "      <td>0.206061</td>\n",
       "      <td>0.810499</td>\n",
       "      <td>0.920818</td>\n",
       "      <td>0.891347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.176300</td>\n",
       "      <td>0.173343</td>\n",
       "      <td>0.316537</td>\n",
       "      <td>0.173343</td>\n",
       "      <td>0.416345</td>\n",
       "      <td>0.252525</td>\n",
       "      <td>0.828980</td>\n",
       "      <td>0.918278</td>\n",
       "      <td>0.893104</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.162700</td>\n",
       "      <td>0.159525</td>\n",
       "      <td>0.308380</td>\n",
       "      <td>0.159525</td>\n",
       "      <td>0.399405</td>\n",
       "      <td>0.244444</td>\n",
       "      <td>0.842613</td>\n",
       "      <td>0.925950</td>\n",
       "      <td>0.898975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.178400</td>\n",
       "      <td>0.179364</td>\n",
       "      <td>0.326750</td>\n",
       "      <td>0.179364</td>\n",
       "      <td>0.423514</td>\n",
       "      <td>0.202020</td>\n",
       "      <td>0.823040</td>\n",
       "      <td>0.919762</td>\n",
       "      <td>0.890681</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.155200</td>\n",
       "      <td>0.190727</td>\n",
       "      <td>0.345905</td>\n",
       "      <td>0.190727</td>\n",
       "      <td>0.436722</td>\n",
       "      <td>0.165657</td>\n",
       "      <td>0.811830</td>\n",
       "      <td>0.926734</td>\n",
       "      <td>0.901747</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.147600</td>\n",
       "      <td>0.168431</td>\n",
       "      <td>0.309619</td>\n",
       "      <td>0.168431</td>\n",
       "      <td>0.410403</td>\n",
       "      <td>0.270707</td>\n",
       "      <td>0.833827</td>\n",
       "      <td>0.921360</td>\n",
       "      <td>0.895945</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.131900</td>\n",
       "      <td>0.237685</td>\n",
       "      <td>0.403622</td>\n",
       "      <td>0.237685</td>\n",
       "      <td>0.487529</td>\n",
       "      <td>0.123232</td>\n",
       "      <td>0.765501</td>\n",
       "      <td>0.924909</td>\n",
       "      <td>0.895372</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.123800</td>\n",
       "      <td>0.147135</td>\n",
       "      <td>0.292189</td>\n",
       "      <td>0.147135</td>\n",
       "      <td>0.383581</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.854837</td>\n",
       "      <td>0.930266</td>\n",
       "      <td>0.902528</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.154100</td>\n",
       "      <td>0.244350</td>\n",
       "      <td>0.404923</td>\n",
       "      <td>0.244350</td>\n",
       "      <td>0.494318</td>\n",
       "      <td>0.139394</td>\n",
       "      <td>0.758924</td>\n",
       "      <td>0.928458</td>\n",
       "      <td>0.900330</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.134700</td>\n",
       "      <td>0.140614</td>\n",
       "      <td>0.283546</td>\n",
       "      <td>0.140614</td>\n",
       "      <td>0.374986</td>\n",
       "      <td>0.258586</td>\n",
       "      <td>0.861270</td>\n",
       "      <td>0.929136</td>\n",
       "      <td>0.902535</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.118100</td>\n",
       "      <td>0.275546</td>\n",
       "      <td>0.429520</td>\n",
       "      <td>0.275546</td>\n",
       "      <td>0.524925</td>\n",
       "      <td>0.109091</td>\n",
       "      <td>0.728147</td>\n",
       "      <td>0.928039</td>\n",
       "      <td>0.901670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.108200</td>\n",
       "      <td>0.147418</td>\n",
       "      <td>0.291472</td>\n",
       "      <td>0.147418</td>\n",
       "      <td>0.383951</td>\n",
       "      <td>0.254545</td>\n",
       "      <td>0.854557</td>\n",
       "      <td>0.931194</td>\n",
       "      <td>0.903304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1600</td>\n",
       "      <td>0.122300</td>\n",
       "      <td>0.175769</td>\n",
       "      <td>0.317797</td>\n",
       "      <td>0.175769</td>\n",
       "      <td>0.419248</td>\n",
       "      <td>0.216162</td>\n",
       "      <td>0.826587</td>\n",
       "      <td>0.930072</td>\n",
       "      <td>0.905045</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1700</td>\n",
       "      <td>0.103700</td>\n",
       "      <td>0.149956</td>\n",
       "      <td>0.292371</td>\n",
       "      <td>0.149956</td>\n",
       "      <td>0.387241</td>\n",
       "      <td>0.256566</td>\n",
       "      <td>0.852054</td>\n",
       "      <td>0.928620</td>\n",
       "      <td>0.902747</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1800</td>\n",
       "      <td>0.113100</td>\n",
       "      <td>0.151774</td>\n",
       "      <td>0.300504</td>\n",
       "      <td>0.151774</td>\n",
       "      <td>0.389582</td>\n",
       "      <td>0.234343</td>\n",
       "      <td>0.850260</td>\n",
       "      <td>0.926146</td>\n",
       "      <td>0.897174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1900</td>\n",
       "      <td>0.112500</td>\n",
       "      <td>0.176348</td>\n",
       "      <td>0.320711</td>\n",
       "      <td>0.176348</td>\n",
       "      <td>0.419939</td>\n",
       "      <td>0.236364</td>\n",
       "      <td>0.826015</td>\n",
       "      <td>0.927960</td>\n",
       "      <td>0.900024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2000</td>\n",
       "      <td>0.095000</td>\n",
       "      <td>0.155098</td>\n",
       "      <td>0.294825</td>\n",
       "      <td>0.155098</td>\n",
       "      <td>0.393825</td>\n",
       "      <td>0.270707</td>\n",
       "      <td>0.846980</td>\n",
       "      <td>0.927588</td>\n",
       "      <td>0.902574</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2100</td>\n",
       "      <td>0.102900</td>\n",
       "      <td>0.147959</td>\n",
       "      <td>0.295345</td>\n",
       "      <td>0.147959</td>\n",
       "      <td>0.384655</td>\n",
       "      <td>0.218182</td>\n",
       "      <td>0.854023</td>\n",
       "      <td>0.929358</td>\n",
       "      <td>0.904430</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2200</td>\n",
       "      <td>0.105800</td>\n",
       "      <td>0.193425</td>\n",
       "      <td>0.348361</td>\n",
       "      <td>0.193425</td>\n",
       "      <td>0.439801</td>\n",
       "      <td>0.155556</td>\n",
       "      <td>0.809168</td>\n",
       "      <td>0.924740</td>\n",
       "      <td>0.899824</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2300</td>\n",
       "      <td>0.093400</td>\n",
       "      <td>0.159523</td>\n",
       "      <td>0.302839</td>\n",
       "      <td>0.159523</td>\n",
       "      <td>0.399403</td>\n",
       "      <td>0.224242</td>\n",
       "      <td>0.842615</td>\n",
       "      <td>0.928153</td>\n",
       "      <td>0.899901</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2400</td>\n",
       "      <td>0.088400</td>\n",
       "      <td>0.144712</td>\n",
       "      <td>0.281538</td>\n",
       "      <td>0.144712</td>\n",
       "      <td>0.380410</td>\n",
       "      <td>0.270707</td>\n",
       "      <td>0.857228</td>\n",
       "      <td>0.928516</td>\n",
       "      <td>0.902445</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2500</td>\n",
       "      <td>0.101300</td>\n",
       "      <td>0.172362</td>\n",
       "      <td>0.310078</td>\n",
       "      <td>0.172362</td>\n",
       "      <td>0.415166</td>\n",
       "      <td>0.256566</td>\n",
       "      <td>0.829948</td>\n",
       "      <td>0.924896</td>\n",
       "      <td>0.894801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2600</td>\n",
       "      <td>0.096600</td>\n",
       "      <td>0.157156</td>\n",
       "      <td>0.305266</td>\n",
       "      <td>0.157156</td>\n",
       "      <td>0.396429</td>\n",
       "      <td>0.220202</td>\n",
       "      <td>0.844951</td>\n",
       "      <td>0.928683</td>\n",
       "      <td>0.897878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2700</td>\n",
       "      <td>0.077400</td>\n",
       "      <td>0.180577</td>\n",
       "      <td>0.328549</td>\n",
       "      <td>0.180577</td>\n",
       "      <td>0.424944</td>\n",
       "      <td>0.185859</td>\n",
       "      <td>0.821843</td>\n",
       "      <td>0.927247</td>\n",
       "      <td>0.897867</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2800</td>\n",
       "      <td>0.087900</td>\n",
       "      <td>0.149698</td>\n",
       "      <td>0.290599</td>\n",
       "      <td>0.149698</td>\n",
       "      <td>0.386908</td>\n",
       "      <td>0.262626</td>\n",
       "      <td>0.852309</td>\n",
       "      <td>0.927887</td>\n",
       "      <td>0.898100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2900</td>\n",
       "      <td>0.085600</td>\n",
       "      <td>0.224526</td>\n",
       "      <td>0.369697</td>\n",
       "      <td>0.224526</td>\n",
       "      <td>0.473841</td>\n",
       "      <td>0.137374</td>\n",
       "      <td>0.778484</td>\n",
       "      <td>0.924841</td>\n",
       "      <td>0.899385</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3000</td>\n",
       "      <td>0.087600</td>\n",
       "      <td>0.174229</td>\n",
       "      <td>0.313649</td>\n",
       "      <td>0.174229</td>\n",
       "      <td>0.417407</td>\n",
       "      <td>0.234343</td>\n",
       "      <td>0.828106</td>\n",
       "      <td>0.923320</td>\n",
       "      <td>0.892304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3100</td>\n",
       "      <td>0.085500</td>\n",
       "      <td>0.180391</td>\n",
       "      <td>0.330034</td>\n",
       "      <td>0.180391</td>\n",
       "      <td>0.424724</td>\n",
       "      <td>0.187879</td>\n",
       "      <td>0.822027</td>\n",
       "      <td>0.928767</td>\n",
       "      <td>0.896099</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3200</td>\n",
       "      <td>0.075600</td>\n",
       "      <td>0.151974</td>\n",
       "      <td>0.294822</td>\n",
       "      <td>0.151974</td>\n",
       "      <td>0.389838</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>0.850063</td>\n",
       "      <td>0.929377</td>\n",
       "      <td>0.897832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3300</td>\n",
       "      <td>0.076200</td>\n",
       "      <td>0.163116</td>\n",
       "      <td>0.317341</td>\n",
       "      <td>0.163116</td>\n",
       "      <td>0.403876</td>\n",
       "      <td>0.189899</td>\n",
       "      <td>0.839070</td>\n",
       "      <td>0.925954</td>\n",
       "      <td>0.892780</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3400</td>\n",
       "      <td>0.080100</td>\n",
       "      <td>0.158841</td>\n",
       "      <td>0.309157</td>\n",
       "      <td>0.158841</td>\n",
       "      <td>0.398548</td>\n",
       "      <td>0.197980</td>\n",
       "      <td>0.843288</td>\n",
       "      <td>0.929622</td>\n",
       "      <td>0.898873</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3500</td>\n",
       "      <td>0.075200</td>\n",
       "      <td>0.154141</td>\n",
       "      <td>0.302981</td>\n",
       "      <td>0.154141</td>\n",
       "      <td>0.392607</td>\n",
       "      <td>0.208081</td>\n",
       "      <td>0.847925</td>\n",
       "      <td>0.928646</td>\n",
       "      <td>0.903221</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3600</td>\n",
       "      <td>0.069200</td>\n",
       "      <td>0.168778</td>\n",
       "      <td>0.317207</td>\n",
       "      <td>0.168778</td>\n",
       "      <td>0.410826</td>\n",
       "      <td>0.202020</td>\n",
       "      <td>0.833484</td>\n",
       "      <td>0.926922</td>\n",
       "      <td>0.896015</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3700</td>\n",
       "      <td>0.069100</td>\n",
       "      <td>0.162452</td>\n",
       "      <td>0.299089</td>\n",
       "      <td>0.162452</td>\n",
       "      <td>0.403053</td>\n",
       "      <td>0.274747</td>\n",
       "      <td>0.839725</td>\n",
       "      <td>0.927053</td>\n",
       "      <td>0.898308</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3800</td>\n",
       "      <td>0.069300</td>\n",
       "      <td>0.182072</td>\n",
       "      <td>0.327250</td>\n",
       "      <td>0.182072</td>\n",
       "      <td>0.426699</td>\n",
       "      <td>0.206061</td>\n",
       "      <td>0.820368</td>\n",
       "      <td>0.926180</td>\n",
       "      <td>0.893159</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3900</td>\n",
       "      <td>0.064200</td>\n",
       "      <td>0.173997</td>\n",
       "      <td>0.320064</td>\n",
       "      <td>0.173997</td>\n",
       "      <td>0.417129</td>\n",
       "      <td>0.208081</td>\n",
       "      <td>0.828335</td>\n",
       "      <td>0.926561</td>\n",
       "      <td>0.894977</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4000</td>\n",
       "      <td>0.067300</td>\n",
       "      <td>0.149395</td>\n",
       "      <td>0.295847</td>\n",
       "      <td>0.149395</td>\n",
       "      <td>0.386517</td>\n",
       "      <td>0.246465</td>\n",
       "      <td>0.852607</td>\n",
       "      <td>0.928324</td>\n",
       "      <td>0.897567</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4100</td>\n",
       "      <td>0.068700</td>\n",
       "      <td>0.142321</td>\n",
       "      <td>0.281090</td>\n",
       "      <td>0.142321</td>\n",
       "      <td>0.377254</td>\n",
       "      <td>0.305051</td>\n",
       "      <td>0.859586</td>\n",
       "      <td>0.930397</td>\n",
       "      <td>0.899337</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4200</td>\n",
       "      <td>0.062600</td>\n",
       "      <td>0.145620</td>\n",
       "      <td>0.283303</td>\n",
       "      <td>0.145620</td>\n",
       "      <td>0.381602</td>\n",
       "      <td>0.284848</td>\n",
       "      <td>0.856331</td>\n",
       "      <td>0.929262</td>\n",
       "      <td>0.897181</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4300</td>\n",
       "      <td>0.062800</td>\n",
       "      <td>0.154560</td>\n",
       "      <td>0.294825</td>\n",
       "      <td>0.154560</td>\n",
       "      <td>0.393141</td>\n",
       "      <td>0.266667</td>\n",
       "      <td>0.847512</td>\n",
       "      <td>0.928661</td>\n",
       "      <td>0.899041</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4400</td>\n",
       "      <td>0.059100</td>\n",
       "      <td>0.154009</td>\n",
       "      <td>0.303061</td>\n",
       "      <td>0.154009</td>\n",
       "      <td>0.392440</td>\n",
       "      <td>0.212121</td>\n",
       "      <td>0.848055</td>\n",
       "      <td>0.927926</td>\n",
       "      <td>0.897917</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4500</td>\n",
       "      <td>0.060700</td>\n",
       "      <td>0.147396</td>\n",
       "      <td>0.283027</td>\n",
       "      <td>0.147396</td>\n",
       "      <td>0.383922</td>\n",
       "      <td>0.298990</td>\n",
       "      <td>0.854579</td>\n",
       "      <td>0.927750</td>\n",
       "      <td>0.896381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4600</td>\n",
       "      <td>0.058000</td>\n",
       "      <td>0.165229</td>\n",
       "      <td>0.304417</td>\n",
       "      <td>0.165229</td>\n",
       "      <td>0.406483</td>\n",
       "      <td>0.252525</td>\n",
       "      <td>0.836986</td>\n",
       "      <td>0.926419</td>\n",
       "      <td>0.895137</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4700</td>\n",
       "      <td>0.053800</td>\n",
       "      <td>0.172587</td>\n",
       "      <td>0.315699</td>\n",
       "      <td>0.172587</td>\n",
       "      <td>0.415435</td>\n",
       "      <td>0.218182</td>\n",
       "      <td>0.829726</td>\n",
       "      <td>0.925502</td>\n",
       "      <td>0.892066</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4800</td>\n",
       "      <td>0.055900</td>\n",
       "      <td>0.160728</td>\n",
       "      <td>0.303209</td>\n",
       "      <td>0.160728</td>\n",
       "      <td>0.400909</td>\n",
       "      <td>0.234343</td>\n",
       "      <td>0.841426</td>\n",
       "      <td>0.925634</td>\n",
       "      <td>0.892730</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4900</td>\n",
       "      <td>0.056000</td>\n",
       "      <td>0.160889</td>\n",
       "      <td>0.296630</td>\n",
       "      <td>0.160889</td>\n",
       "      <td>0.401109</td>\n",
       "      <td>0.284848</td>\n",
       "      <td>0.841268</td>\n",
       "      <td>0.926322</td>\n",
       "      <td>0.896379</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>0.051800</td>\n",
       "      <td>0.157368</td>\n",
       "      <td>0.299634</td>\n",
       "      <td>0.157368</td>\n",
       "      <td>0.396697</td>\n",
       "      <td>0.238384</td>\n",
       "      <td>0.844741</td>\n",
       "      <td>0.926708</td>\n",
       "      <td>0.896471</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5100</td>\n",
       "      <td>0.054100</td>\n",
       "      <td>0.163847</td>\n",
       "      <td>0.307303</td>\n",
       "      <td>0.163847</td>\n",
       "      <td>0.404781</td>\n",
       "      <td>0.228283</td>\n",
       "      <td>0.838349</td>\n",
       "      <td>0.925464</td>\n",
       "      <td>0.892502</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5200</td>\n",
       "      <td>0.050200</td>\n",
       "      <td>0.158387</td>\n",
       "      <td>0.296507</td>\n",
       "      <td>0.158387</td>\n",
       "      <td>0.397978</td>\n",
       "      <td>0.260606</td>\n",
       "      <td>0.843736</td>\n",
       "      <td>0.925112</td>\n",
       "      <td>0.890048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5300</td>\n",
       "      <td>0.054100</td>\n",
       "      <td>0.171898</td>\n",
       "      <td>0.309298</td>\n",
       "      <td>0.171898</td>\n",
       "      <td>0.414606</td>\n",
       "      <td>0.232323</td>\n",
       "      <td>0.830405</td>\n",
       "      <td>0.925199</td>\n",
       "      <td>0.893215</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5400</td>\n",
       "      <td>0.051400</td>\n",
       "      <td>0.158943</td>\n",
       "      <td>0.296321</td>\n",
       "      <td>0.158943</td>\n",
       "      <td>0.398677</td>\n",
       "      <td>0.274747</td>\n",
       "      <td>0.843187</td>\n",
       "      <td>0.926546</td>\n",
       "      <td>0.894952</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5500</td>\n",
       "      <td>0.051700</td>\n",
       "      <td>0.163868</td>\n",
       "      <td>0.304840</td>\n",
       "      <td>0.163868</td>\n",
       "      <td>0.404806</td>\n",
       "      <td>0.234343</td>\n",
       "      <td>0.838328</td>\n",
       "      <td>0.926855</td>\n",
       "      <td>0.894094</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5600</td>\n",
       "      <td>0.050300</td>\n",
       "      <td>0.158556</td>\n",
       "      <td>0.298869</td>\n",
       "      <td>0.158556</td>\n",
       "      <td>0.398191</td>\n",
       "      <td>0.248485</td>\n",
       "      <td>0.843569</td>\n",
       "      <td>0.925873</td>\n",
       "      <td>0.893479</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5700</td>\n",
       "      <td>0.052700</td>\n",
       "      <td>0.161747</td>\n",
       "      <td>0.302273</td>\n",
       "      <td>0.161747</td>\n",
       "      <td>0.402178</td>\n",
       "      <td>0.238384</td>\n",
       "      <td>0.840421</td>\n",
       "      <td>0.926171</td>\n",
       "      <td>0.894233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5800</td>\n",
       "      <td>0.046200</td>\n",
       "      <td>0.165151</td>\n",
       "      <td>0.305089</td>\n",
       "      <td>0.165151</td>\n",
       "      <td>0.406388</td>\n",
       "      <td>0.234343</td>\n",
       "      <td>0.837062</td>\n",
       "      <td>0.926121</td>\n",
       "      <td>0.894689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5900</td>\n",
       "      <td>0.051600</td>\n",
       "      <td>0.160661</td>\n",
       "      <td>0.299216</td>\n",
       "      <td>0.160661</td>\n",
       "      <td>0.400825</td>\n",
       "      <td>0.262626</td>\n",
       "      <td>0.841493</td>\n",
       "      <td>0.925936</td>\n",
       "      <td>0.893627</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6000</td>\n",
       "      <td>0.049500</td>\n",
       "      <td>0.157904</td>\n",
       "      <td>0.296783</td>\n",
       "      <td>0.157904</td>\n",
       "      <td>0.397371</td>\n",
       "      <td>0.262626</td>\n",
       "      <td>0.844212</td>\n",
       "      <td>0.925857</td>\n",
       "      <td>0.893306</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6100</td>\n",
       "      <td>0.048200</td>\n",
       "      <td>0.161413</td>\n",
       "      <td>0.300525</td>\n",
       "      <td>0.161413</td>\n",
       "      <td>0.401762</td>\n",
       "      <td>0.254545</td>\n",
       "      <td>0.840751</td>\n",
       "      <td>0.925834</td>\n",
       "      <td>0.893199</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6200</td>\n",
       "      <td>0.048100</td>\n",
       "      <td>0.161552</td>\n",
       "      <td>0.300872</td>\n",
       "      <td>0.161552</td>\n",
       "      <td>0.401935</td>\n",
       "      <td>0.254545</td>\n",
       "      <td>0.840613</td>\n",
       "      <td>0.925777</td>\n",
       "      <td>0.893088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6300</td>\n",
       "      <td>0.047600</td>\n",
       "      <td>0.162197</td>\n",
       "      <td>0.301699</td>\n",
       "      <td>0.162197</td>\n",
       "      <td>0.402736</td>\n",
       "      <td>0.248485</td>\n",
       "      <td>0.839977</td>\n",
       "      <td>0.925811</td>\n",
       "      <td>0.893249</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=6360, training_loss=0.171045786854606, metrics={'train_runtime': 1817.0606, 'train_samples_per_second': 48.859, 'train_steps_per_second': 3.5, 'total_flos': 28115183972304.0, 'train_loss': 0.171045786854606, 'epoch': 20.0})"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.train()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "emnlp_2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
